Skip to content

Commit 8bc12c8

Browse files
NarineK authored and facebook-github-bot committed
Fix failing GPU errors for influential examples (#1081)
Summary: Due to floating-point arithmetic inaccuracies and limitations, switching to double for test cases and passing dtype to the projection matrix so that the arithmetic error is within the accepted range. Pull Request resolved: #1081 Reviewed By: 99warriors Differential Revision: D41919707 Pulled By: NarineK fbshipit-source-id: f8ef65e751ada7c3d3baddb1231b547c3be1823c
1 parent dcb87d3 commit 8bc12c8

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

captum/influence/_core/tracincp_fast_rand_proj.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1380,6 +1380,7 @@ def _set_projections_tracincp_fast_rand_proj(
13801380
1
13811381
] # this is the dimension of the input of the last fully-connected layer
13821382
device = batch_jacobians.device
1383+
dtype = batch_jacobians.dtype
13831384

13841385
# choose projection if needed
13851386
# without projection, the dimension of the intermediate quantities returned
@@ -1409,8 +1410,8 @@ def _set_projections_tracincp_fast_rand_proj(
14091410
)
14101411

14111412
projection_quantities = jacobian_projection.to(
1412-
device
1413-
), layer_input_projection.to(device)
1413+
device=device, dtype=dtype
1414+
), layer_input_projection.to(device=device, dtype=dtype)
14141415

14151416
return projection_quantities
14161417

tests/influence/_utils/common.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,35 +190,42 @@ def get_random_model_and_data(
190190
BasicLinearNet(in_features, hidden_nodes, out_features)
191191
if not unpack_inputs
192192
else MultLinearNet(in_features, hidden_nodes, out_features, num_inputs)
193-
)
193+
).double()
194194

195195
num_checkpoints = 5
196196

197197
for i in range(num_checkpoints):
198-
net.linear1.weight.data = torch.normal(3, 4, (hidden_nodes, in_features))
199-
net.linear2.weight.data = torch.normal(5, 6, (out_features, hidden_nodes))
198+
net.linear1.weight.data = torch.normal(
199+
3, 4, (hidden_nodes, in_features)
200+
).double()
201+
net.linear2.weight.data = torch.normal(
202+
5, 6, (out_features, hidden_nodes)
203+
).double()
200204
if unpack_inputs:
201205
net.pre.weight.data = torch.normal(
202206
3, 4, (in_features, in_features * num_inputs)
203207
)
208+
if hasattr(net, "pre"):
209+
net.pre.weight.data = net.pre.weight.data.double()
204210
checkpoint_name = "-".join(["checkpoint-reg", str(i + 1) + ".pt"])
205211
net_adjusted = _wrap_model_in_dataparallel(net) if use_gpu else net
206212
torch.save(net_adjusted.state_dict(), os.path.join(tmpdir, checkpoint_name))
207213

208214
num_samples = 50
209215
num_train = 32
210-
all_labels = torch.normal(1, 2, (num_samples, out_features))
216+
all_labels = torch.normal(1, 2, (num_samples, out_features)).double()
211217
train_labels = all_labels[:num_train]
212218
test_labels = all_labels[num_train:]
213219

214220
if unpack_inputs:
215221
all_samples = [
216-
torch.normal(0, 1, (num_samples, in_features)) for _ in range(num_inputs)
222+
torch.normal(0, 1, (num_samples, in_features)).double()
223+
for _ in range(num_inputs)
217224
]
218225
train_samples = [ts[:num_train] for ts in all_samples]
219226
test_samples = [ts[num_train:] for ts in all_samples]
220227
else:
221-
all_samples = torch.normal(0, 1, (num_samples, in_features))
228+
all_samples = torch.normal(0, 1, (num_samples, in_features)).double()
222229
train_samples = all_samples[:num_train]
223230
test_samples = all_samples[num_train:]
224231

0 commit comments

Comments
 (0)