🌐 AI搜索 & 代理 主页
Skip to content

Commit b5f8f1c

Browse files
virchan and ogrisel authored
ENH add Array API support for d2_pinball_score and d2_absolute_error_score (#31671)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 92fb813 commit b5f8f1c

File tree

5 files changed

+70
-35
lines changed

5 files changed

+70
-35
lines changed

doc/modules/array_api.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,10 @@ Metrics
164164
- :func:`sklearn.metrics.cluster.calinski_harabasz_score`
165165
- :func:`sklearn.metrics.cohen_kappa_score`
166166
- :func:`sklearn.metrics.confusion_matrix`
167+
- :func:`sklearn.metrics.d2_absolute_error_score`
167168
- :func:`sklearn.metrics.d2_brier_score`
168169
- :func:`sklearn.metrics.d2_log_loss_score`
170+
- :func:`sklearn.metrics.d2_pinball_score`
169171
- :func:`sklearn.metrics.d2_tweedie_score`
170172
- :func:`sklearn.metrics.det_curve`
171173
- :func:`sklearn.metrics.explained_variance_score`
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
- :func:`sklearn.metrics.d2_absolute_error_score` and
2+
:func:`sklearn.metrics.d2_pinball_score` now support array API compatible inputs.
3+
By :user:`Virgil Chan <virchan>`.
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
- :func:`metrics.d2_pinball_score` and :func:`metrics.d2_absolute_error_score` now
2+
always use the `"averaged_inverted_cdf"` quantile method, both with and
3+
without sample weights. Previously, the `"linear"` quantile method was used only
4+
for the unweighted case leading to surprising discrepancies when comparing the
5+
results with unit weights. Note that all quantile interpolation methods are
6+
asymptotically equivalent in the large sample limit, but this fix can cause score
7+
value changes on small evaluation sets (without weights).
8+
By :user:`Virgil Chan <virchan>`.

sklearn/metrics/_regression.py

Lines changed: 32 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -936,7 +936,7 @@ def median_absolute_error(
936936
return float(_average(output_errors, weights=multioutput, xp=xp))
937937

938938

939-
def _assemble_r2_explained_variance(
939+
def _assemble_fraction_of_explained_deviance(
940940
numerator, denominator, n_outputs, multioutput, force_finite, xp, device
941941
):
942942
"""Common part used by explained variance, :math:`R^2` and :math:`D^2` scores."""
@@ -1121,7 +1121,7 @@ def explained_variance_score(
11211121
(y_true - y_true_avg) ** 2, weights=sample_weight, axis=0, xp=xp
11221122
)
11231123

1124-
return _assemble_r2_explained_variance(
1124+
return _assemble_fraction_of_explained_deviance(
11251125
numerator=numerator,
11261126
denominator=denominator,
11271127
n_outputs=y_true.shape[1],
@@ -1300,7 +1300,7 @@ def r2_score(
13001300
axis=0,
13011301
)
13021302

1303-
return _assemble_r2_explained_variance(
1303+
return _assemble_fraction_of_explained_deviance(
13041304
numerator=numerator,
13051305
denominator=denominator,
13061306
n_outputs=y_true.shape[1],
@@ -1779,9 +1779,9 @@ def d2_pinball_score(
17791779
>>> d2_pinball_score(y_true, y_pred)
17801780
0.5
17811781
>>> d2_pinball_score(y_true, y_pred, alpha=0.9)
1782-
0.772...
1782+
0.666...
17831783
>>> d2_pinball_score(y_true, y_pred, alpha=0.1)
1784-
-1.045...
1784+
-1.999...
17851785
>>> d2_pinball_score(y_true, y_true, alpha=0.1)
17861786
1.0
17871787
@@ -1803,9 +1803,14 @@ def d2_pinball_score(
18031803
>>> grid.best_params_
18041804
{'fit_intercept': True}
18051805
"""
1806-
_, y_true, y_pred, sample_weight, multioutput = _check_reg_targets(
1806+
xp, _, device_ = get_namespace_and_device(
18071807
y_true, y_pred, sample_weight, multioutput
18081808
)
1809+
_, y_true, y_pred, sample_weight, multioutput = (
1810+
_check_reg_targets_with_floating_dtype(
1811+
y_true, y_pred, sample_weight, multioutput, xp=xp
1812+
)
1813+
)
18091814

18101815
if _num_samples(y_pred) < 2:
18111816
msg = "D^2 score is not well-defined with less than two samples."
@@ -1821,16 +1826,18 @@ def d2_pinball_score(
18211826
)
18221827

18231828
if sample_weight is None:
1824-
y_quantile = np.tile(
1825-
np.percentile(y_true, q=alpha * 100, axis=0), (len(y_true), 1)
1826-
)
1827-
else:
1828-
y_quantile = np.tile(
1829-
_weighted_percentile(
1830-
y_true, sample_weight=sample_weight, percentile_rank=alpha * 100
1831-
),
1832-
(len(y_true), 1),
1833-
)
1829+
sample_weight = xp.ones([y_true.shape[0]], dtype=y_true.dtype, device=device_)
1830+
1831+
y_quantile = xp.tile(
1832+
_weighted_percentile(
1833+
y_true,
1834+
sample_weight=sample_weight,
1835+
percentile_rank=alpha * 100,
1836+
average=True,
1837+
xp=xp,
1838+
),
1839+
(y_true.shape[0], 1),
1840+
)
18341841

18351842
denominator = mean_pinball_loss(
18361843
y_true,
@@ -1840,25 +1847,15 @@ def d2_pinball_score(
18401847
multioutput="raw_values",
18411848
)
18421849

1843-
nonzero_numerator = numerator != 0
1844-
nonzero_denominator = denominator != 0
1845-
valid_score = nonzero_numerator & nonzero_denominator
1846-
output_scores = np.ones(y_true.shape[1])
1847-
1848-
output_scores[valid_score] = 1 - (numerator[valid_score] / denominator[valid_score])
1849-
output_scores[nonzero_numerator & ~nonzero_denominator] = 0.0
1850-
1851-
if isinstance(multioutput, str):
1852-
if multioutput == "raw_values":
1853-
# return scores individually
1854-
return output_scores
1855-
else: # multioutput == "uniform_average"
1856-
# passing None as weights to np.average results in uniform mean
1857-
avg_weights = None
1858-
else:
1859-
avg_weights = multioutput
1860-
1861-
return float(np.average(output_scores, weights=avg_weights))
1850+
return _assemble_fraction_of_explained_deviance(
1851+
numerator=numerator,
1852+
denominator=denominator,
1853+
n_outputs=y_true.shape[1],
1854+
multioutput=multioutput,
1855+
force_finite=True,
1856+
xp=xp,
1857+
device=device_,
1858+
)
18621859

18631860

18641861
@validate_params(

sklearn/metrics/tests/test_common.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,11 @@
148148
"mean_compound_poisson_deviance": partial(mean_tweedie_deviance, power=1.4),
149149
"d2_tweedie_score": partial(d2_tweedie_score, power=1.4),
150150
"d2_pinball_score": d2_pinball_score,
151+
# The default `alpha=0.5` (median) masks differences between quantile methods,
152+
# so we also test `alpha=0.1` and `alpha=0.9` to ensure correctness
153+
# for non-median quantiles.
154+
"d2_pinball_score_01": partial(d2_pinball_score, alpha=0.1),
155+
"d2_pinball_score_09": partial(d2_pinball_score, alpha=0.9),
151156
"d2_absolute_error_score": d2_absolute_error_score,
152157
}
153158

@@ -492,6 +497,8 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
492497
"mean_absolute_percentage_error",
493498
"mean_pinball_loss",
494499
"d2_pinball_score",
500+
"d2_pinball_score_01",
501+
"d2_pinball_score_09",
495502
"d2_absolute_error_score",
496503
}
497504

@@ -563,6 +570,8 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
563570
"mean_compound_poisson_deviance",
564571
"d2_tweedie_score",
565572
"d2_pinball_score",
573+
"d2_pinball_score_01",
574+
"d2_pinball_score_09",
566575
"d2_absolute_error_score",
567576
"mean_absolute_percentage_error",
568577
}
@@ -2358,6 +2367,22 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
23582367
check_array_api_regression_metric,
23592368
check_array_api_regression_metric_multioutput,
23602369
],
2370+
d2_absolute_error_score: [
2371+
check_array_api_regression_metric,
2372+
check_array_api_regression_metric_multioutput,
2373+
],
2374+
d2_pinball_score: [
2375+
check_array_api_regression_metric,
2376+
check_array_api_regression_metric_multioutput,
2377+
],
2378+
partial(d2_pinball_score, alpha=0.1): [
2379+
check_array_api_regression_metric,
2380+
check_array_api_regression_metric_multioutput,
2381+
],
2382+
partial(d2_pinball_score, alpha=0.9): [
2383+
check_array_api_regression_metric,
2384+
check_array_api_regression_metric_multioutput,
2385+
],
23612386
d2_tweedie_score: [
23622387
check_array_api_regression_metric,
23632388
],

0 commit comments

Comments
 (0)