Commit c40d4dc

FEAT - Implement SmoothQuantileRegression (#312)
1 parent 1d03cdb commit c40d4dc

File tree

6 files changed: +426 −0 lines changed


doc/api.rst

Lines changed: 2 additions & 0 deletions
@@ -105,5 +105,7 @@ Experimental
 IterativeReweightedL1
 PDCD_WS
 Pinball
+QuantileHuber
+SmoothQuantileRegressor
 SqrtQuadratic
 SqrtLasso

doc/changes/0.5.rst

Lines changed: 1 addition & 0 deletions
@@ -3,3 +3,4 @@
 Version 0.5 (in progress)
 -------------------------
 - Add support for fitting an intercept in :ref:`SqrtLasso <skglm.experimental.sqrt_lasso.SqrtLasso>` (PR: :gh:`298`)
+- Add experimental :ref:`QuantileHuber <skglm.experimental.quantile_huber.QuantileHuber>` and :ref:`SmoothQuantileRegressor <skglm.experimental.quantile_huber.SmoothQuantileRegressor>` for quantile regression, and an example script (PR: :gh:`312`).

examples/plot_smooth_quantile.py

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
"""
================================================================================
Smooth Quantile Regression with QuantileHuber
================================================================================

This example compares sklearn's standard quantile regression with skglm's smooth
approximation. Skglm replaces the non-differentiable pinball loss with a smooth
Huber-like approximation (quadratic near zero, linear in the tails). Progressive
smoothing enables efficient gradient-based optimization, maintaining speed and
accuracy even on large-scale, high-dimensional datasets.
"""

# Author: Florian Kozikowski
import numpy as np
import time
import matplotlib.pyplot as plt

from sklearn.datasets import make_regression
from sklearn.linear_model import QuantileRegressor
from skglm.experimental.quantile_huber import QuantileHuber, SmoothQuantileRegressor

# Generate regression data
X, y = make_regression(n_samples=1000, n_features=10, noise=0.1, random_state=0)
tau = 0.8  # 80th percentile

# %%
# Compare standard vs smooth quantile regression
# ----------------------------------------------
# Both methods solve the same problem but with different loss functions.

# Standard quantile regression (sklearn)
start = time.time()
sk_model = QuantileRegressor(quantile=tau, alpha=0.1)
sk_model.fit(X, y)
sk_time = time.time() - start

# Smooth quantile regression (skglm)
start = time.time()
smooth_model = SmoothQuantileRegressor(
    quantile=tau,
    alpha=0.1,
    delta_init=0.5,  # Initial smoothing parameter
    delta_final=0.01,  # Final smoothing (smaller = closer to true quantile)
    n_deltas=5  # Number of continuation steps
)
smooth_model.fit(X, y)
smooth_time = time.time() - start

# %%
# Evaluate both methods
# ---------------------
# Coverage: fraction of true values below predictions (should ≈ tau)
# Pinball loss: standard quantile regression evaluation metric
#
# Note: no robust benchmarking has been conducted yet. The speed advantage likely
# only shows on large-scale, high-dimensional datasets; the sklearn implementation
# is likely faster on small datasets.


def pinball_loss(residuals, quantile):
    return np.mean(residuals * (quantile - (residuals < 0)))


sk_pred = sk_model.predict(X)
smooth_pred = smooth_model.predict(X)

print(f"{'Method':<15} {'Coverage':<10} {'Time (s)':<10} {'Pinball Loss':<12}")
print("-" * 50)
print(f"{'Sklearn':<15} {np.mean(y <= sk_pred):<10.3f} {sk_time:<10.3f} "
      f"{pinball_loss(y - sk_pred, tau):<12.4f}")
print(f"{'SmoothQuantile':<15} {np.mean(y <= smooth_pred):<10.3f} {smooth_time:<10.3f} "
      f"{pinball_loss(y - smooth_pred, tau):<12.4f}")

# %%
# Visualize the smooth approximation
# ----------------------------------
# The smooth loss approximates the pinball loss but with continuous gradients.

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

# Show loss and gradient for different quantile levels
residuals = np.linspace(-3, 3, 500)
delta = 0.5
quantiles = [0.1, 0.5, 0.9]

for tau_val in quantiles:
    qh = QuantileHuber(quantile=tau_val, delta=delta)
    loss = [qh._loss_sample(r) for r in residuals]
    grad = [qh._grad_per_sample(r) for r in residuals]

    # Compute pinball loss for each residual (named pinball_vals to avoid
    # shadowing the pinball_loss function defined above)
    pinball_vals = [r * (tau_val - (r < 0)) for r in residuals]

    # Plot smooth loss and pinball loss
    ax1.plot(residuals, loss, label=f"τ={tau_val}", linewidth=2)
    ax1.plot(residuals, pinball_vals, '--', alpha=0.4, color='gray',
             label=f"Pinball τ={tau_val}")
    ax2.plot(residuals, grad, label=f"τ={tau_val}", linewidth=2)

# Add vertical lines and shading showing delta boundaries
for ax in [ax1, ax2]:
    ax.axvline(-delta, color='gray', linestyle='--', alpha=0.7, linewidth=1.5)
    ax.axvline(delta, color='gray', linestyle='--', alpha=0.7, linewidth=1.5)
    # Add shading for quadratic region
    ax.axvspan(-delta, delta, alpha=0.15, color='gray')

# Add delta labels
ax1.text(-delta, 0.1, '−δ', ha='right', va='bottom', color='gray', fontsize=10)
ax1.text(delta, 0.1, '+δ', ha='left', va='bottom', color='gray', fontsize=10)

ax1.set_title(f"Smooth Quantile Loss (δ={delta})", fontsize=12)
ax1.set_xlabel("Residual")
ax1.set_ylabel("Loss")
ax1.legend(loc='upper left')
ax1.grid(True, alpha=0.3)

ax2.set_title("Gradient (continuous everywhere)", fontsize=12)
ax2.set_xlabel("Residual")
ax2.set_ylabel("Gradient")
ax2.legend(loc='upper left')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# %% [markdown]
# The left plot shows the asymmetric loss: tau=0.1 penalizes overestimation more,
# while tau=0.9 penalizes underestimation more. As delta decreases towards zero,
# the loss function approaches the standard pinball loss.
# The right plot reveals the key advantage: gradients transition smoothly through
# zero, unlike the standard pinball loss, which has a kink at zero. This smoothing
# enables fast convergence with gradient-based solvers.
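The example's docstring describes the loss as quadratic near zero and linear in the tails. The snippet below is a minimal, self-contained NumPy sketch of such a Huber-smoothed pinball loss; it is illustrative only and is not the QuantileHuber implementation added in this commit, whose exact scaling may differ.

import numpy as np


def smoothed_pinball(r, tau, delta):
    """Illustrative Huber-smoothed pinball loss (not skglm's exact implementation).

    Quadratic for |r| <= delta, linear beyond, with asymmetric weights:
    tau on positive residuals and (1 - tau) on negative residuals.
    """
    r = np.asarray(r, dtype=float)
    weight = np.where(r >= 0, tau, 1 - tau)  # asymmetric slope
    quad = weight * r ** 2 / (2 * delta)     # smooth region near zero
    lin = weight * (np.abs(r) - delta / 2)   # linear tails
    return np.where(np.abs(r) <= delta, quad, lin)


# As delta shrinks, the smoothed loss approaches the pinball loss:
r = np.array([-1.0, 1.0])
for d in [0.5, 0.1, 0.01]:
    print(d, smoothed_pinball(r, tau=0.1, delta=d))
# For tau=0.1 the penalty at r=-1 (overestimation) approaches 0.9,
# while the penalty at r=+1 (underestimation) approaches 0.1.

For τ=0.1 the tail slopes are 0.9 on negative residuals and 0.1 on positive ones, which is the asymmetry visible in the left plot of the example.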

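The delta_init, delta_final, and n_deltas parameters used in the example define a continuation scheme: the smoothing level is decreased over several steps, each solve warm-starting from the previous solution. One plausible schedule is sketched below; the geometric spacing and the solve_smooth_quantile helper are assumptions made for illustration, not skglm's actual internals.

import numpy as np

# Hypothetical continuation schedule (illustration only; the spacing actually
# used by SmoothQuantileRegressor may differ).
delta_init, delta_final, n_deltas = 0.5, 0.01, 5
deltas = np.geomspace(delta_init, delta_final, n_deltas)
print(deltas)  # approximately [0.5, 0.188, 0.0707, 0.0266, 0.01]

# Conceptual continuation loop, warm-starting each solve from the last:
# coef = np.zeros(n_features)
# for delta in deltas:
#     coef = solve_smooth_quantile(X, y, tau, delta, w_init=coef)  # hypothetical helper
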
skglm/experimental/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
from .sqrt_lasso import SqrtLasso, SqrtQuadratic
33
from .pdcd_ws import PDCD_WS
44
from .quantile_regression import Pinball
5+
from .quantile_huber import QuantileHuber, SmoothQuantileRegressor
56

67
__all__ = [
78
IterativeReweightedL1,
89
PDCD_WS,
910
Pinball,
11+
QuantileHuber,
12+
SmoothQuantileRegressor,
1013
SqrtQuadratic,
1114
SqrtLasso,
1215
]
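
With these additions, both new classes are importable directly from skglm.experimental. A minimal usage sketch (parameter values are illustrative, following the example script above):

from sklearn.datasets import make_regression
from skglm.experimental import QuantileHuber, SmoothQuantileRegressor

X, y = make_regression(n_samples=200, n_features=5, noise=0.1, random_state=0)

# Fit the smooth quantile regressor at the 80th percentile
model = SmoothQuantileRegressor(quantile=0.8, alpha=0.1)
model.fit(X, y)
preds = model.predict(X)

# The smooth datafit can also be instantiated directly
datafit = QuantileHuber(quantile=0.8, delta=0.1)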

0 commit comments
