Commit f6398f8

fix b class on GPU (#2167)

1 parent 714a7ef commit f6398f8

12 files changed: +326 -45 lines changed

mindnlp/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@
     mindspore.set_device(os.environ.get('DEVICE_TARGET'))

     # for different ascend devices
-    if platform.system().lower() == 'linux':
+    if platform.system().lower() == 'linux' and mindspore.get_context('device_target') == 'Ascend':
         SOC = MSContext.get_instance().get_ascend_soc_version()
         # enable vmm since only vmm can release device memory when del tensor.
         if SOC != 'ascend310b':
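
Editor's note: the practical effect of this one-line change is that on a Linux machine whose device target is GPU or CPU the Ascend SOC probe is skipped entirely instead of calling into MSContext, which is what failed on non-Ascend builds. A minimal sketch of the new guard (illustrative only; the helper name is hypothetical and simply restates the condition added above):

    import platform
    import mindspore

    def should_probe_ascend_soc() -> bool:
        # The SOC version (and the vmm setting that depends on it) only matters on
        # Ascend; GPU/CPU targets must not touch MSContext's Ascend-only queries.
        return (platform.system().lower() == 'linux'
                and mindspore.get_context('device_target') == 'Ascend')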

mindnlp/core/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@
 preserve_format = None
 legacy_contiguous_format = None
 channels_last_3d = None
+channels_last = None
 memory_format = None

 inf = float("inf")

mindnlp/core/_apis/cpu.py

Lines changed: 9 additions & 0 deletions
@@ -1221,3 +1221,12 @@ def logsumexp(input, dim, keepdim=False):

 def bernoulli(input, generator):
     return legacy.bernoulli(input, seed, offset)
+
+def right_shift(input, other):
+    return legacy.right_shift(input, other)
+
+def histc(input, bins=100, min=0, max=0):
+    return legacy.histogram(input, bins, float(min), float(max))
+
+def search_sorted(sorted_sequence, values, sorter, dtype, right):
+    return legacy.search_sorted(sorted_sequence, values, sorter, dtype, right)
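
Editor's note: these CPU wrappers follow torch-style semantics. A hedged usage sketch of the histogram op, assuming the dispatcher exposes it as Tensor.histc the way the _tensor.py change further below does:

    from mindnlp import core

    x = core.tensor([0.5, 1.5, 1.6, 2.5])
    # 3 equal-width bins over [0, 3): expected counts [1, 2, 1]
    hist = x.histc(bins=3, min=0, max=3)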

mindnlp/core/_apis/gpu.py

Lines changed: 37 additions & 21 deletions
@@ -4,7 +4,7 @@
 import mindspore
 from mindspore._c_expression import _empty_instance
 from mindnlp import core
-from .._op_prim.cpu import legacy
+from .._op_prim.gpu import legacy

 try:
     from mindspore._c_expression import TensorPy as Tensor_
@@ -34,6 +34,8 @@ def fill_scalar(size, fill_value, dtype):
     return legacy.cast(legacy.fill_v2(size, mindspore.Tensor(fill_value)), dtype)

 def fill_tensor(size, fill_value, dtype):
+    if dtype is None:
+        return legacy.fill_v2(size, mindspore.Tensor(fill_value))
     return legacy.cast(legacy.fill_v2(size, fill_value), dtype)

 def zeros_like(input, dtype):
@@ -123,6 +125,9 @@ def div(input, other):
     return legacy.div(input, other)

 def mul(input, other):
+    if input.dtype == core.bool:
+        if isinstance(other, bool) or (not isinstance(other, numbers.Number) and other.dtype == core.bool):
+            return bitwise_and_scalar(input, other)
     return legacy.mul(input, other)

 def reduce_all(input, axis, keepdims):
@@ -253,6 +258,11 @@ def less(input, other):
     return legacy.less(input, other)

 def select(condition, x, y):
+    if isinstance(x, numbers.Number) or x.ndim == 0:
+        x = fill_scalar(condition.shape, x, None)
+    if isinstance(y, numbers.Number) or y.ndim == 0:
+        y = fill_scalar(condition.shape, y, None)
+
     return legacy.select(condition, x, y)

 def round(input, decimals):
@@ -317,16 +327,15 @@ def ones_like(input, dtype):
     return legacy.ones_like(input)

 def embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq):
-    return cast(legacy.gather(weight, input, 0, 0), weight.dtype)
+    return legacy.gather(weight, input, 0, 0)

 def linspace(start, end, steps, dtype):
     start = float(start)
     end = float(end)
     return legacy.lin_space(mindspore.Tensor(start), mindspore.Tensor(end), steps)

 def masked_fill(input, mask, value):
-    if input.dtype.is_floating_point and isinstance(value, numbers.Number):
-        value = float(value)
+    value = fill_scalar((), value, input.dtype)
     return legacy.masked_fill(input, mask, value)

 def sum(input, dim, keepdim, dtype):
@@ -388,9 +397,14 @@ def layer_norm(input, normalized_shape, weight, bias, eps=1e-5):
     return legacy.layer_norm(input, weight, bias, begin_axis, begin_axis, eps)

 def argmin_with_value(input, axis, keep_dims):
+    if axis is None:
+        axis = -1
     return legacy.arg_min_with_value(input, axis, keep_dims)

 def argmax_with_value(input, axis, keep_dims):
+    if axis is None:
+        axis = -1
+
     return legacy.arg_max_with_value(input, axis, keep_dims)

 def silu(input):
@@ -425,9 +439,13 @@ def eye(n, m, dtype):
     return legacy.eye(n, m, dtype)

 def argmax(input, axis, keep_dims):
+    if axis is None:
+        axis = -1
     return legacy.arg_max_with_value(input, axis, keep_dims)[0]

 def argmin(input, axis, keep_dims):
+    if axis is None:
+        axis = -1
     return legacy.arg_min_with_value(input, axis, keep_dims)[0]

 def exp(input):
@@ -489,18 +507,7 @@ def scatter(input, dim, index, src):
     return legacy.tensor_scatter_elements(input, index, src, dim, "none")

 def batch_norm(input, weight, bias, running_mean=None, runnning_var=None, training=False, momentum=0.1, epsilon=1e-5):
-    input_ndim = input.ndim
-    if input_ndim == 2:
-        return legacy.batch_norm(input, weight, bias, running_mean, runnning_var, training, epsilon, momentum, 'NCHW')
-    else:
-        input = transpose_view(input, 1, -1)
-        input_shape = input.shape
-        input = reshape(input, (-1, input.shape[-1]))
-        outs = legacy.batch_norm(input, weight, bias, running_mean, runnning_var, training, epsilon, momentum, 'NCHW')
-        out = reshape(outs[0], (*input_shape[:-1], -1))
-        out = transpose_view(out, 1, -1)
-
-        return out, outs[1], outs[2]
+    return legacy.batch_norm(input, weight, bias, running_mean, runnning_var, training, epsilon, momentum, 'NCHW')

 def tanh(input):
     return legacy.tanh(input)
@@ -797,25 +804,22 @@ def max_pool2d(input, kernel_size, stride=1, padding=0, dilation=1, ceil_mode=Fa
     return out

 def baddbmm(input, batch1, batch2, alpha=1, beta=1):
-    return add(mul(beta, input), mul(alpha, bmm(batch1, batch2)))
+    return add(mul(input, beta), mul(bmm(batch1, batch2), alpha))

 def softplus(input, beta=1, threshold=20):
     return legacy.softplus(input)

 def gather_nd(input, indices):
     return legacy.gather_nd(input, indices)

-def unique_consecutive(input, return_inverse, return_counts, dim):
-    return legacy.unique_consecutive(input, return_inverse, return_counts, dim)
-
 def meshgrid(input, lambd):
     return legacy.meshgrid(input, lambd)

 def addcmul(input, tensor1, tensor2, value=1.0):
     return legacy.addcmul(input, tensor1, tensor2, mindspore.Tensor(value))

 def addmm(input, mat1, mat2, alpha=1.0, beta=1.0):
-    return add(mul(beta, input), mul(alpha, bmm(mat1, mat2)))
+    return add(mul(input, beta), mul(bmm(mat1, mat2), alpha))

 def im2col(input, kernel_size, dilation=1, padding=0, stride=1):
     out = legacy.im2_col(input, kernel_size, stride, dilation, padding)
@@ -1101,6 +1105,8 @@ def bernoulli(input, generator):
     return legacy.bernoulli(input, seed, offset)

 def arange(start, end, step, dtype):
+    if dtype is not None:
+        return cast(legacy.range(start, end, step, 100000), dtype)
     return legacy.range(start, end, step, 100000)

 def inplace_fill_scalar(input, value):
@@ -1121,3 +1127,13 @@ def inplace_uniform(input, from_, to_, generator_):
                             mindspore.tensor(from_, dtype=mindspore.int32),
                             mindspore.tensor(to_, dtype=mindspore.int32), 0, 0)
     return input.assign_value(value)
+
+def right_shift(input, other):
+    return legacy.right_shift(input, other)
+
+def inplace_fill_tensor(input, value):
+    input.assign_value(fill_tensor(input.shape, value, None))
+    return input
+
+def search_sorted(sorted_sequence, values, sorter, dtype, right):
+    return legacy.search_sorted(sorted_sequence, values, sorter, dtype, right)
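
Editor's note: two of the GPU fixes above are behavioural rather than cosmetic. A hedged illustration of what they restore, assuming core.tensor and core.where mirror the torch-style API (not part of the commit): multiplying bool tensors must act as logical AND, and where/select must accept a Python scalar for either branch.

    from mindnlp import core

    a = core.tensor([True, False, True])
    b = core.tensor([True, True, False])
    print(a * b)                      # bool * bool -> logical AND: [True, False, False]

    cond = core.tensor([True, False, True])
    x = core.tensor([1.0, 2.0, 3.0])
    print(core.where(cond, x, 0.0))   # scalar branch is broadcast: [1.0, 0.0, 3.0]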

mindnlp/core/_apis/npu.py

Lines changed: 10 additions & 0 deletions
@@ -1594,3 +1594,13 @@ def bernoulli(input, generator):
 def multinomial(input, num_samples, replacement, generator):
     seed, offset = generator._step(12)  # pylint: disable=protected-access
     return pyboost.multinomial_ext_op(input, num_samples, replacement, seed, offset)
+
+def right_shift(input, other):
+    if use_pyboost():
+        return pyboost.right_shift_op(input, other)
+    return legacy.right_shift(input, other)
+
+def histc(input, bins=100, min=0, max=0):
+    if use_pyboost():
+        return pyboost.histc_ext_op(input, bins, float(min), float(max))
+    return legacy.histogram(input, bins, float(min), float(max))
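
Editor's note: both new NPU ops keep the usual dispatch pattern (pyboost kernel when available, legacy primitive otherwise). The semantics they expose are the standard torch ones, shown here with plain Python values rather than the framework API:

    # bitwise right shift: each element moves toward the low bits
    assert 8 >> 1 == 4 and 5 >> 2 == 1

    # histc: counts per equal-width bin over [min, max]
    values, bins, lo, hi = [0.5, 1.5, 1.6, 2.5], 3, 0.0, 3.0
    width = (hi - lo) / bins
    counts = [sum(lo + i * width <= v < lo + (i + 1) * width for v in values) for i in range(bins)]
    assert counts == [1, 2, 1]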

mindnlp/core/_tensor.py

Lines changed: 5 additions & 3 deletions
@@ -110,6 +110,7 @@ def __init__(self, *args, **kwargs):

 Tensor.__init__ = __init__
 origin_setitem = Tensor.__setitem__
+origin_is_contiguous = Tensor.is_contiguous
 Tensor._requires_grad = False

 def tensor(data, *, dtype=None, device=None, requires_grad=False):
@@ -1253,7 +1254,8 @@ def hardshrink(self, lambd=0.5):


 # Tensor.histc
-
+def histc(self, bins=100, min=0, max=0):
+    return ops.histc(self, bins, min, max)

 # Tensor.histogram

@@ -1364,8 +1366,8 @@ def isnan(self):
     return ops.isnan(self)

 # Tensor.is_contiguous
-# def is_contiguous(self):
-#     return self.is_contiguous()
+def is_contiguous(self, memory_format=None):
+    return origin_is_contiguous(self)

 # Tensor.is_complex
 def is_complex(self):
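
Editor's note: the is_contiguous patch captures MindSpore's original bound method first, then accepts (and ignores) a torch-style memory_format argument. A hedged usage sketch of what these patches make available, assuming the module rebinds histc and is_contiguous onto Tensor as it does for the other patched methods:

    from mindnlp import core

    t = core.tensor([[0.2, 1.7], [2.4, 0.9]])
    print(t.is_contiguous(memory_format=core.channels_last))  # extra arg accepted, original check used
    print(t.histc(bins=3, min=0, max=3))                      # routed to ops.histc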

mindnlp/core/cuda/__init__.py

Lines changed: 13 additions & 2 deletions
@@ -60,8 +60,19 @@ def __exit__(self, type: Any, value: Any, traceback: Any):
 def is_bf16_supported():
     return False

-def mem_get_info(index):
-    return (1024, 1024)
+def mem_get_info(device=None):
+    if not isinstance(device, int):
+        device = mindspore.context.get_context("device_id")
+
+    res = mindspore.hal.get_device_properties(device)
+    return (res.total_memory, res.total_memory)
+
+def get_device_capability(device=None):
+    if not isinstance(device, int):
+        device = mindspore.context.get_context("device_id")
+
+    res = mindspore.hal.get_device_properties(device)
+    return (res.major, res.minor)

 def memory_reserved(device=None):
     return ms_memory_reserved()
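
Editor's note: these shims now report real device properties instead of placeholder values, matching the torch.cuda calls common model code makes. A hedged usage sketch (attribute names follow mindspore.hal.get_device_properties, assumed available on a GPU build):

    from mindnlp.core import cuda

    free, total = cuda.mem_get_info()            # note: both values are total_memory here
    major, minor = cuda.get_device_capability()
    print(f"sm_{major}{minor}, {total / 2**30:.1f} GiB reported")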

mindnlp/core/nn/functional.py

Lines changed: 46 additions & 13 deletions
@@ -274,7 +274,7 @@ def pad(input, pad, mode='constant', value=None):
     if isinstance(pad, tuple):
         pad = tuple(p if isinstance(p, int) else p.item() for p in pad)

-    if input.device.type in ['cpu', 'meta'] or ON_A1:
+    if input.device.type in ['cpu', 'meta', 'cuda'] or ON_A1:
         new_pad = ()
         for idx, pad_v in enumerate(pad):
             if not isinstance(pad_v, int):
@@ -301,6 +301,8 @@ def pad(input, pad, mode='constant', value=None):
                 value = bool(value)
             elif input.dtype in [core.int32, core.int64]:
                 value = int(value)
+        if input.device.type == 'cuda' and len(new_pad) == 8:
+            return execute('pad_v3', input, new_pad[:-2], mode, value)
         return execute('pad_v3', input, new_pad, mode, value)
     out = input
     if (isinstance(pad, tuple) and not pad):
@@ -324,9 +326,9 @@
     return out

 def nll_loss(input, target, weight=None, ignore_index=-100, reduction='mean'):
-    # if input.device.type == 'npu':
-    return _nllloss_nd(input, target, weight, ignore_index, reduction)
-    # return _inner_nll_loss(input, target, weight, ignore_index, reduction)
+    if input.device.type in ['npu', 'cpu']:
+        return _nllloss_nd(input, target, weight, ignore_index, reduction)
+    return _inner_nll_loss(input, target, weight, ignore_index, reduction)

 def _inner_nll_loss(inputs, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0):
     ndim = inputs.ndim
@@ -352,7 +354,7 @@ def _inner_nll_loss(inputs, target, weight=None, ignore_index=-100, reduction='m
 def _nll_loss(inputs, target, target_dim=-1, weight=None, ignore_index=None, reduction='none', label_smoothing=0.0):
     """nll loss inner function"""
     if target.ndim == inputs.ndim - 1:
-        target = target.expand_dims(target_dim)
+        target = target.unsqueeze(target_dim)
     if ignore_index is not None:
         non_pad_mask = core.eq(target, ignore_index)
         target = target.masked_fill(non_pad_mask, core.cast(0, target.dtype))
@@ -366,10 +368,10 @@
         weight = weight.view(weight.shape + (1,))
         weighted_inputs = inputs * weight
         weighted_inputs = weighted_inputs.view(orig_shape)
-        loss = core.neg(core.gather_d(weighted_inputs, target_dim, target))
+        loss = core.neg(core.gather(weighted_inputs, target_dim, target))
         smooth_loss = core.neg(weighted_inputs.sum(axis=target_dim, keepdims=True))
     else:
-        loss = core.neg(core.gather_d(inputs, target_dim, target))
+        loss = core.neg(core.gather(inputs, target_dim, target))
         smooth_loss = core.neg(inputs.sum(axis=target_dim, keepdims=True))
     loss_weights = core.ones_like(loss)

@@ -427,11 +429,42 @@ def _nllloss_nd(input, target, weight=None, ingore_index=-100, reduction='mean')
         ret = execute('nllloss_2d', input, target, weight, reduction, ingore_index)[0]
     return ret.view(out_size)

+
+def cross_entropy_gpu(input, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0):
+    class_dim = 0 if input.ndim == 1 else 1
+    if target.dtype.is_floating_point:
+        return _cross_entropy(input, target, class_dim, weight, reduction, label_smoothing)
+    return nll_loss(log_softmax(input, class_dim), target, weight, ignore_index, reduction)
+
+def _cross_entropy(inputs, target, target_dim, weight=None, reduction='mean', label_smoothing=0.0):
+    """cross entropy inner function"""
+    class_dim = 0 if inputs.ndim == 1 else 1
+    n_classes = inputs.shape[class_dim]
+    inputs = log_softmax(inputs, class_dim)
+    if label_smoothing > 0.0:
+        target = target * (1 - label_smoothing) + label_smoothing / n_classes
+
+    if weight is None:
+        weight = core.ones_like(inputs)
+    elif inputs.ndim != 1:
+        broadcast_shape = [1 for _ in range(inputs.ndim)]
+        broadcast_shape[1] = weight.shape[0]
+        weight = weight.reshape(broadcast_shape)
+
+    if reduction == 'mean':
+        return -(inputs * target * weight).sum() / (inputs.nel / n_classes)
+    if reduction == 'sum':
+        return -(inputs * target * weight).sum()
+    return -(inputs * target * weight).sum(class_dim)
+
+
 def cross_entropy(input, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0):
     if label_smoothing < 0.0 or label_smoothing > 1.0:
         raise ValueError(f"For cross_entropy, label_smoothing must in [0, 1]")
     if input.ndim == 0 or input.shape[0] == 0:
         raise ValueError(f"For cross_entropy, input don't support 0-dim and shape[0].")
+    if input.device.type == 'cuda':
+        return cross_entropy_gpu(input, target, weight, ignore_index, reduction, label_smoothing)
     class_dim = 0 if input.ndim == 1 else 1
     n_classes = input.shape[class_dim]
     input = log_softmax(input, class_dim, dtype=input.dtype)
@@ -675,10 +708,10 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne
         )
     if input.dim() == 4 and mode == "bicubic":
         assert align_corners is not None
-        if antialias:
-            return torch._C._nn._upsample_bicubic2d_aa(
-                input, output_size, align_corners, scale_factors
-            )
+        # if antialias:
+        #     return torch._C._nn._upsample_bicubic2d_aa(
+        #         input, output_size, align_corners, scale_factors
+        #     )
         return execute(
             'upsample_bicubic2d', input, output_size, scale_factors, align_corners
         )
@@ -1146,8 +1179,8 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
     else:
         attn_bias = attn_mask + attn_bias

-    attn_weight = query.float() @ key.transpose(-2, -1).float() * scale_factor
-    attn_weight += attn_bias.float()
+    attn_weight = query @ key.transpose(-2, -1) * scale_factor
+    attn_weight += attn_bias
     attn_weight = softmax(attn_weight, dim=-1, dtype=core.float32).to(query.dtype)
     attn_weight = dropout(attn_weight, dropout_p, training=True)
     return attn_weight @ value
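
Editor's note: the new CUDA branch of cross_entropy splits on the target dtype: integer class indices go through log_softmax + nll_loss (which now also has a real GPU path), while floating-point soft targets use the manual _cross_entropy reduction. A hedged usage sketch of both branches (core.randn and core.ones are assumed to mirror the torch API):

    from mindnlp import core
    import mindnlp.core.nn.functional as F

    logits = core.randn(4, 10)
    hard_targets = core.tensor([1, 0, 3, 9])   # int indices -> nll_loss branch
    soft_targets = core.ones(4, 10) / 10       # probabilities -> _cross_entropy branch

    loss_hard = F.cross_entropy(logits, hard_targets)
    loss_soft = F.cross_entropy(logits, soft_targets, label_smoothing=0.1)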

mindnlp/core/ops/_inner.py

Lines changed: 7 additions & 0 deletions
@@ -16,7 +16,14 @@ def npu_clear_float_status_v2(status):
 def all_finite(inputs):
     return execute('all_finite', inputs)

+def custom_masked_scatter_vec(input, mask, source):
+    output = input.clone()
+    output[mask] = source.flatten()  # key line: vectorized assignment
+    return output
+
 def masked_scatter(input, mask, source):
+    if input.device.type == 'cuda':
+        return custom_masked_scatter_vec(input, mask, source)
     return execute('masked_scatter', input, mask, source)

 __all__ = [
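
Editor's note: the CUDA fallback reproduces torch.Tensor.masked_scatter semantics with plain boolean-mask assignment: values from source are written, in order, into the True positions of mask. A small hedged illustration of the expected result (core.tensor and the ops-level import path are assumptions based on this module's exports):

    from mindnlp import core

    x = core.tensor([0, 0, 0, 0])
    mask = core.tensor([True, False, True, False])
    src = core.tensor([1, 2])
    print(core.ops.masked_scatter(x, mask, src))   # expected: [1, 0, 2, 0]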
