Skip to content

Commit b83ad19

Browse files
committed
minimal testing for A-SOC
1 parent da6a6c7 commit b83ad19

File tree

4 files changed

+643
-34
lines changed

4 files changed

+643
-34
lines changed

orthogonium/layers/conv/adaptiveSOC/fast_skew_ortho_conv.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,17 @@ def attach_soc_weight(
138138
weight_name (str): name of the weight
139139
kernel_shape (tuple): shape of the kernel (out_channels, in_channels/groups, kernel_size, kernel_size)
140140
groups (int): number of groups
141-
bjorck_params (BjorckParams, optional): parameters of the Bjorck orthogonalization. Defaults to BjorckParams().
141+
exp_params (ExpParams): parameters for the exponential algorithm.
142142
143143
Returns:
144144
torch.Tensor: a handle to the attached weight
145145
"""
146146
out_channels, in_channels, kernel_size, k2 = kernel_shape
147147
in_channels *= groups # compute the real number of input channels
148-
assert kernel_size == k2, "only square kernels are supported for the moment"
148+
assert (
149+
kernel_size == k2
150+
), "only square kernels are supported (to compute skew symmetric kernels)"
151+
assert kernel_size % 2 == 1, "kernel size must be odd"
149152
max_channels = max(in_channels, out_channels)
150153
layer.register_parameter(
151154
weight_name,
@@ -250,11 +253,6 @@ def __init__(
250253
groups,
251254
exp_params=exp_params,
252255
)
253-
if bias:
254-
self.bias = nn.Parameter(torch.Tensor(out_channels))
255-
nn.init.zeros_(self.bias)
256-
else:
257-
self.register_parameter("bias", None)
258256

259257
def singular_values(self):
260258
"""Compute the singular values of the convolutional layer using the FFT+SVD method.
@@ -363,12 +361,6 @@ def __init__(
363361
exp_params=exp_params,
364362
)
365363

366-
if bias:
367-
self.bias = nn.Parameter(torch.Tensor(out_channels))
368-
nn.init.zeros_(self.bias)
369-
else:
370-
self.register_parameter("bias", None)
371-
372364
def singular_values(self):
373365
if self.padding_mode != "circular":
374366
print(
@@ -383,8 +375,8 @@ def singular_values(self):
383375
self.groups,
384376
self.in_channels // self.groups,
385377
self.out_channels // self.groups,
386-
self.kernel_size,
387-
self.kernel_size,
378+
self.weight.shape[-2],
379+
self.weight.shape[-1],
388380
)
389381
.numpy(),
390382
self._input_shape,

orthogonium/layers/conv/adaptiveSOC/ortho_conv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def AdaptiveSOCConv2d(
4040
)
4141
if kernel_size == stride:
4242
convclass = RKOConv2d
43-
elif (stride == 1) or (in_channels >= out_channels):
43+
elif stride == 1:
4444
convclass = FastSOC
4545
else:
4646
convclass = SOCRkoConv2d

orthogonium/layers/conv/adaptiveSOC/soc_x_rko_conv.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,26 @@ def __init__(
7575
self.groups = groups
7676
self.intermediate_channels = max(in_channels, out_channels // stride**2)
7777
del self.weight
78+
int_kernel_size = kernel_size - (stride - 1)
79+
if int_kernel_size % 2 == 0:
80+
if int_kernel_size <= 2:
81+
int_kernel_size += 1
82+
else:
83+
int_kernel_size -= 1
84+
# warn user that kernel size changed
85+
warnings.warn(
86+
f"kernel size changed from {kernel_size} to {int_kernel_size} "
87+
f"as even kernel size is not supported for SOC.",
88+
RuntimeWarning,
89+
)
7890
attach_soc_weight(
7991
self,
8092
"weight_1",
8193
(
8294
self.intermediate_channels,
8395
in_channels // groups,
84-
kernel_size - (stride - 1),
85-
kernel_size - (stride - 1),
96+
int_kernel_size,
97+
int_kernel_size,
8698
),
8799
groups,
88100
exp_params=exp_params,
@@ -96,12 +108,6 @@ def __init__(
96108
ortho_params=ortho_params,
97109
)
98110

99-
if bias:
100-
self.bias = nn.Parameter(torch.Tensor(out_channels))
101-
nn.init.zeros_(self.bias)
102-
else:
103-
self.register_parameter("bias", None)
104-
105111
@property
106112
def weight(self):
107113
if self.training:
@@ -237,14 +243,26 @@ def __init__(
237243
# RuntimeWarning,
238244
# )
239245
del self.weight
246+
int_kernel_size = kernel_size - (stride - 1)
247+
if int_kernel_size % 2 == 0:
248+
if int_kernel_size <= 2:
249+
int_kernel_size += 1
250+
else:
251+
int_kernel_size -= 1
252+
# warn user that kernel size changed
253+
warnings.warn(
254+
f"kernel size changed from {kernel_size} to {int_kernel_size} "
255+
f"as even kernel size is not supported for SOC.",
256+
RuntimeWarning,
257+
)
240258
attach_soc_weight(
241259
self,
242260
"weight_1",
243261
(
244262
self.intermediate_channels,
245263
out_channels // groups,
246-
kernel_size - (stride - 1),
247-
kernel_size - (stride - 1),
264+
int_kernel_size,
265+
int_kernel_size,
248266
),
249267
groups,
250268
exp_params=exp_params,
@@ -258,12 +276,6 @@ def __init__(
258276
ortho_params=ortho_params,
259277
)
260278

261-
if bias:
262-
self.bias = nn.Parameter(torch.Tensor(out_channels))
263-
nn.init.zeros_(self.bias)
264-
else:
265-
self.register_parameter("bias", None)
266-
267279
def singular_values(self):
268280
if self.padding_mode != "circular":
269281
print(
@@ -278,8 +290,8 @@ def singular_values(self):
278290
self.groups,
279291
self.intermediate_channels // self.groups,
280292
self.out_channels // self.groups,
281-
self.kernel_size,
282-
self.kernel_size,
293+
self.weight_1.shape[-2],
294+
self.weight_1.shape[-1],
283295
)
284296
.numpy(),
285297
self._input_shape,

0 commit comments

Comments
 (0)