Skip to content
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
888a37b
Adding KPCovC to docs
rvasav26 May 21, 2025
4d537a1
Changing assertTrue to assertEqual for correctness
rvasav26 May 25, 2025
de226db
Investigating into KPCovC inconsistencies
rvasav26 May 27, 2025
6dada5d
Trying out some things for KPCovC problems
rvasav26 May 29, 2025
315f358
Changing KPCovC's test_precomputed_classification
rvasav26 May 29, 2025
e653076
Continuing KPCovC investigation
rvasav26 Jun 4, 2025
844c16e
Changing _BasePCov and _BaseKPCov to be abstract base classes
rvasav26 Jun 1, 2025
3616619
Cleaning up print statements
rvasav26 Jun 4, 2025
004499a
Merging PCovC update
rvasav26 Jun 4, 2025
c676e10
Removing KPCovC experiment
rvasav26 Jun 4, 2025
7d04666
Trying mixing=1.0 for KPCovC/PCovC match
rvasav26 Jun 5, 2025
115a224
Switching KPCovC back to using SVC
rvasav26 Jun 8, 2025
bed1b4b
Minor edits after cleaning up KPCovC branch
rvasav26 Jun 8, 2025
e93b86f
Checking scaling and LinearSVC match
rvasav26 Jun 10, 2025
69dd1b6
Working on docstrings
rvasav26 Jun 10, 2025
b701e23
Adding example drafts
rvasav26 Jun 12, 2025
9e4e3d8
Switching from KPCovC w/SVC back to KPCovC w/linear classifiers
rvasav26 Jun 17, 2025
07aba83
Finalizing examples
rvasav26 Jun 18, 2025
fe6b0c7
Modifying tests
rvasav26 Jun 18, 2025
623bc1f
Modifying docstrings and minor edits
rvasav26 Jun 18, 2025
7cef97c
Updating CHANGELOG
rvasav26 Jun 18, 2025
e0ecb03
Formatting
rvasav26 Jun 18, 2025
26a246e
More formatting and cleaning
rvasav26 Jun 19, 2025
15579d6
Minor edits
rvasav26 Jun 23, 2025
993b215
CHANGELOG suggestion
rvasav26 Jun 25, 2025
4962500
Example suggestions
rvasav26 Jun 25, 2025
5e619ae
Docstring and other skmatter/decomposition suggestions
rvasav26 Jun 25, 2025
a1a316a
Christian's suggestions and decision_function tests
rvasav26 Jun 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@ The rules for CHANGELOG file:

Unreleased
----------
- Add ``_BaseKPCov`` class (#254)
- Add ``KernelPCovC`` class that inherits shared functionality from ``_BaseKPCov`` (#254)
- Add ``KernelPCovC`` testing suite and examples (#254)
- Modify ``KernelPCovR`` to inherit shared functionality from ``_BaseKPCov`` (#254)

0.3.0 (2025/06/12)
------------------
- Add ``_BasePCov`` class (#248)
- Add ``PCovC`` class that inherits shared functionality from ``_BasePCov`` (#248)
- Add ``PCovC`` testing suite and examples (#248)
- Modify ``PCovR`` to inherit shared functionality from ``_BasePCov_`` (#248)
- Modify ``PCovR`` to inherit shared functionality from ``_BasePCov`` (#248)
- Update to sklearn >= 1.7.0 and scipy >= 1.15.0 (#239, #257)
- Fixed moved function import from scipy and bump scipy dependency to 1.15.0 (#236)
- Fix rendering issues for `SparseKDE` and `QuickShift` (#236)
Expand Down
16 changes: 16 additions & 0 deletions docs/src/references/decomposition.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,19 @@ Kernel PCovR
.. automethod:: predict
.. automethod:: inverse_transform
.. automethod:: score

.. _KPCovC-api:

Kernel PCovC
------------

.. autoclass:: skmatter.decomposition.KernelPCovC
:show-inheritance:
:special-members:

.. automethod:: fit
.. automethod:: transform
.. automethod:: predict
.. automethod:: inverse_transform
.. automethod:: decision_function
.. automethod:: score
261 changes: 261 additions & 0 deletions examples/pcovc/KPCovC_Comparison.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I LOVE this example.

Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
#!/usr/bin/env python
# coding: utf-8

"""
Comparing KPCovC with KPCA
======================================
"""
# %%
#

import numpy as np

import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.colors import ListedColormap

from sklearn import datasets
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC
from sklearn.decomposition import PCA, KernelPCA
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.model_selection import train_test_split
from sklearn.linear_model import (
LogisticRegressionCV,
RidgeClassifier,
SGDClassifier,
)

from skmatter.decomposition import PCovC, KernelPCovC

plt.rcParams["scatter.edgecolors"] = "k"
cm_bright = ListedColormap(["#d7191c", "#fdae61", "#a6d96a", "#3a7cdf"])

random_state = 0
n_components = 2

# %%
#
# For this, we will combine two ``sklearn`` datasets from
# :func:`sklearn.datasets.make_moons`.

X1, y1 = datasets.make_moons(n_samples=750, noise=0.10, random_state=random_state)
X2, y2 = datasets.make_moons(n_samples=750, noise=0.10, random_state=random_state)

X2, y2 = X2 + 2, y2 + 2
R = np.array(
[
[np.cos(np.pi / 2), -np.sin(np.pi / 2)],
[np.sin(np.pi / 2), np.cos(np.pi / 2)],
]
)
# rotate second pair of moons
X2 = X2 @ R.T

X = np.vstack([X1, X2])
y = np.concatenate([y1, y2])

# %%
#
# Original Data
# -------------

fig, ax = plt.subplots(figsize=(5.5, 5))
ax.scatter(X[:, 0], X[:, 1], c=y, cmap=cm_bright)
ax.set_title("Original Data")


# %%
#
# Scale Data

X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.25, stratify=y, random_state=random_state
)

scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# %%
#
# PCA and PCovC
# -------------

mixing = 0.10
alpha_d = 0.5
alpha_p = 0.4

models = {
PCA(n_components=n_components): "PCA",
PCovC(
n_components=n_components,
random_state=random_state,
mixing=mixing,
classifier=LinearSVC(),
): "PCovC",
}

fig, axs = plt.subplots(1, 2, figsize=(10, 4))

for ax, model in zip(axs, models):
t_train = (
model.fit_transform(X_train_scaled)
if isinstance(model, PCA)
else model.fit_transform(X_train_scaled, y_train)
)
t_test = model.transform(X_test_scaled)

ax.scatter(t_test[:, 0], t_test[:, 1], alpha=alpha_d, cmap=cm_bright, c=y_test)
ax.scatter(t_train[:, 0], t_train[:, 1], cmap=cm_bright, c=y_train)

ax.set_title(models[model])
plt.tight_layout()

# %%
#
# Kernel PCA and Kernel PCovC
# ---------------------------

fig, axs = plt.subplots(1, 2, figsize=(13, 6))

center = True
resolution = 1000

kernel_params = {"kernel": "rbf", "gamma": 2}

models = {
KernelPCA(n_components=n_components, **kernel_params): {
"title": "Kernel PCA",
"eps": 0.1,
},
KernelPCovC(
n_components=n_components,
random_state=random_state,
mixing=mixing,
center=center,
**kernel_params,
): {"title": "Kernel PCovC", "eps": 2},
}

for ax, model in zip(axs, models):
t_train = model.fit_transform(X_train_scaled, y_train)
t_test = model.transform(X_test_scaled)

if isinstance(model, KernelPCA):
t_classifier = LinearSVC(random_state=random_state).fit(t_train, y_train)
score = t_classifier.score(t_test, y_test)
else:
t_classifier = model.classifier_
score = model.score(X_test_scaled, y_test)

DecisionBoundaryDisplay.from_estimator(
estimator=t_classifier,
X=t_test,
ax=ax,
response_method="predict",
cmap=cm_bright,
alpha=alpha_d,
eps=models[model]["eps"],
grid_resolution=resolution,
)
ax.scatter(t_test[:, 0], t_test[:, 1], alpha=alpha_p, cmap=cm_bright, c=y_test)
ax.scatter(t_train[:, 0], t_train[:, 1], cmap=cm_bright, c=y_train)
ax.set_title(models[model]["title"])

ax.text(
0.82,
0.03,
f"Score: {round(score, 3)}",
fontsize=mpl.rcParams["axes.titlesize"],
transform=ax.transAxes,
)
ax.set_xticks([])
ax.set_yticks([])

fig.subplots_adjust(wspace=0.04)
plt.tight_layout()


# %%
#
# Effect of KPCovC Classifier on KPCovC Maps and Decision Boundaries
# ------------------------------------------------------------------------------
#
# Based on the evidence :math:`\mathbf{Z}` generated by the underlying classifier fit
# on a computed kernel :math:`\mathbf{K}` and :math:`\mathbf{Y}`, Kernel PCovC will
# produce varying latent space maps. Hence, the decision boundaries produced by the
# linear classifier fit between :math:`\mathbf{T}` and :math:`\mathbf{Y}` to make
# predictions will also vary.

names = ["Logistic Regression", "Ridge Classifier", "Linear SVC", "SGD Classifier"]

models = {
LogisticRegressionCV(random_state=random_state): {
"kernel_params": {"kernel": "rbf", "gamma": 12},
"title": "Logistic Regression",
},
RidgeClassifier(random_state=random_state): {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could also use RidgeClassifierCV here!

"kernel_params": {"kernel": "rbf", "gamma": 1},
"title": "Ridge Classifier",
"eps": 0.25,
},
LinearSVC(random_state=random_state): {
"kernel_params": {"kernel": "rbf", "gamma": 15},
"title": "Support Vector Classification",
},
SGDClassifier(random_state=random_state): {
"kernel_params": {"kernel": "rbf", "gamma": 15},
"title": "SGD Classifier",
"eps": 10,
},
}

fig, axs = plt.subplots(1, len(models), figsize=(4 * len(models), 4))

for ax, name, model in zip(axs.flat, names, models):
kpcovc = KernelPCovC(
n_components=n_components,
random_state=random_state,
mixing=mixing,
classifier=model,
center=center,
**models[model]["kernel_params"],
)
t_kpcovc_train = kpcovc.fit_transform(X_train_scaled, y_train)
t_kpcovc_test = kpcovc.transform(X_test_scaled)
kpcovc_score = kpcovc.score(X_test_scaled, y_test)

DecisionBoundaryDisplay.from_estimator(
estimator=kpcovc.classifier_,
X=t_kpcovc_test,
ax=ax,
response_method="predict",
cmap=cm_bright,
alpha=alpha_d,
eps=models[model].get("eps", 1),
grid_resolution=resolution,
)

ax.scatter(
t_kpcovc_test[:, 0],
t_kpcovc_test[:, 1],
cmap=cm_bright,
alpha=alpha_p,
c=y_test,
)
ax.scatter(t_kpcovc_train[:, 0], t_kpcovc_train[:, 1], cmap=cm_bright, c=y_train)
ax.text(
0.70,
0.03,
f"Score: {round(kpcovc_score, 3)}",
fontsize=mpl.rcParams["axes.titlesize"],
transform=ax.transAxes,
)

ax.set_title(name)
ax.set_xticks([])
ax.set_yticks([])
fig.subplots_adjust(wspace=0.04)

plt.tight_layout()
Loading