
Commit 14bceb7

fix diffusers pipelines s class ut (#2096)
1 parent 5bc3b2b commit 14bceb7

4 files changed (+330 −2 lines)


mindnlp/core/_tensor.py

Lines changed: 15 additions & 0 deletions
@@ -82,6 +82,14 @@ def tensor(data, *, dtype=None, device=None, requires_grad=False):
         UserWarning("To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than core.tensor(sourceTensor).")
         return Tensor(data)
 
+    if isinstance(data, list):
+        new_data = []
+        for d in data:
+            if isinstance(d, Tensor):
+                d = d.item()
+            new_data.append(d)
+        data = new_data
+
     if device is None:
         device = get_default_device()
 
@@ -730,6 +738,13 @@ def __iter__(self):
 Tensor.__iter__ = __iter__
 StubTensor.__iter__ = __iter__
 
+def __float__(self):
+    out = self.item()
+    return round(float(out), 5)
+
+Tensor.__float__ = __float__
+StubTensor.__float__ = __float__
+
 def _rebuild_from_type_v2(func, new_type, args, state):
     ret = func(*args)
     return ret
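
Taken together, the two hunks above let core.tensor accept a Python list whose elements are Tensors (each is unwrapped through .item()), and give Tensor and StubTensor a __float__ that goes through .item() and rounds to 5 decimal places. A minimal usage sketch, assuming core is importable as mindnlp.core and that item() behaves as in the torch API this mirrors:

    from mindnlp import core  # assumption: core is the torch-like namespace used in the diff

    # list elements that are Tensors are unwrapped with .item() before construction,
    # so a list mixing Python scalars and 0-d tensors builds a plain 1-D tensor
    parts = [core.tensor(1.0), 2.0, core.tensor(3.0)]
    t = core.tensor(parts)

    # float(tensor) now routes through .item() and rounds to 5 decimal places
    x = core.tensor(0.123456789)
    print(float(x))  # ~0.12346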

mindnlp/core/linalg/__init__.py

Lines changed: 6 additions & 1 deletion
@@ -1,4 +1,6 @@
 from collections import namedtuple
+import numpy as np
+
 from mindspore import ops, mint
 from mindspore.ops._primitive_cache import _get_cache_prim
 
@@ -25,4 +27,7 @@ def norm(A, ord=None, dim=None, keepdim=False, *, out=None, dtype=None):
     return mint.norm(A, 2 if ord is None else ord, dim, keepdim, dtype=dtype)
 
 def vector_norm(x, ord=2, dim=None, keepdim=False, *, dtype=None, out=None):
-    return mint.linalg.vector_norm(x, ord, dim, keepdim, dtype=dtype)
+    return mint.linalg.vector_norm(x, ord, dim, keepdim, dtype=dtype)
+
+def solve(A, B, *, left=True, out=None):
+    return core.tensor(np.linalg.solve(A.numpy(), B.numpy()))
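
The new solve is a NumPy fallback: it converts both operands with .numpy(), calls np.linalg.solve, and wraps the result back into a tensor, so it runs on the host and the left and out keywords are accepted but not used. A sketch of a call, assuming core.linalg re-exports this module's functions and core.tensor accepts NumPy arrays (both are assumptions, not shown in the diff):

    import numpy as np
    from mindnlp import core

    # solve A @ x = B for x via the NumPy-backed fallback above
    A = core.tensor(np.array([[3.0, 1.0], [1.0, 2.0]], dtype=np.float32))
    B = core.tensor(np.array([[9.0], [8.0]], dtype=np.float32))
    x = core.linalg.solve(A, B)  # expected x ≈ [[2.], [3.]]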

mindnlp/core/nn/modules/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 from .dropout import Dropout, Dropout2d
 from .activation import *
 from .conv import Conv3d, Conv2d, Conv1d, ConvTranspose2d, ConvTranspose1d, ConvTranspose3d
-from .padding import ZeroPad2d, ConstantPad2d, ConstantPad1d, ConstantPad3d
+from .padding import *
 from .batchnorm import BatchNorm2d, BatchNorm1d, SyncBatchNorm
 from .pooling import AdaptiveAvgPool2d, AvgPool1d, MaxPool2d, MaxPool1d, AdaptiveAvgPool1d, AvgPool2d
 from .flatten import Unflatten, Flatten
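
Switching to a wildcard import re-exports every public name defined in padding.py, including the reflection and replication classes added in the next file, without keeping this import list in sync by hand. A hedged sketch of the effect, assuming mindnlp.core.nn exposes its modules symbols in the usual torch-style layout (an assumption, not shown here):

    from mindnlp.core import nn

    # previously only ZeroPad2d and ConstantPad{1,2,3}d were re-exported;
    # with the wildcard, the new padding classes resolve as well
    refl = nn.ReflectionPad2d(2)
    repl = nn.ReplicationPad1d((3, 1))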

mindnlp/core/nn/modules/padding.py

Lines changed: 308 additions & 0 deletions
@@ -251,3 +251,311 @@ def __init__(self, padding: _size_6_t) -> None:
 
     def extra_repr(self) -> str:
         return f'{self.padding}'
+
+class _ReflectionPadNd(Module):
+    __constants__ = ["padding"]
+    padding: Sequence[int]
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.pad(input, self.padding, "reflect")
+
+    def extra_repr(self) -> str:
+        return f"{self.padding}"
+
+
+class ReflectionPad1d(_ReflectionPadNd):
+    r"""Pads the input tensor using the reflection of the input boundary.
+
+    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
+
+    Args:
+        padding (int, tuple): the size of the padding. If is `int`, uses the same
+            padding in all boundaries. If a 2-`tuple`, uses
+            (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)
+            Note that padding size should be less than the corresponding input dimension.
+
+    Shape:
+        - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.
+        - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where
+
+          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
+
+    Examples::
+
+        >>> m = nn.ReflectionPad1d(2)
+        >>> # xdoctest: +IGNORE_WANT("other tests seem to modify printing styles")
+        >>> input = torch.arange(8, dtype=torch.float).reshape(1, 2, 4)
+        >>> input
+        tensor([[[0., 1., 2., 3.],
+                 [4., 5., 6., 7.]]])
+        >>> m(input)
+        tensor([[[2., 1., 0., 1., 2., 3., 2., 1.],
+                 [6., 5., 4., 5., 6., 7., 6., 5.]]])
+        >>> # using different paddings for different sides
+        >>> m = nn.ReflectionPad1d((3, 1))
+        >>> m(input)
+        tensor([[[3., 2., 1., 0., 1., 2., 3., 2.],
+                 [7., 6., 5., 4., 5., 6., 7., 6.]]])
+    """
+
+    padding: tuple[int, int]
+
+    def __init__(self, padding: _size_2_t) -> None:
+        super().__init__()
+        self.padding = _pair(padding)
+
+
+class ReflectionPad2d(_ReflectionPadNd):
+    r"""Pads the input tensor using the reflection of the input boundary.
+
+    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
+
+    Args:
+        padding (int, tuple): the size of the padding. If is `int`, uses the same
+            padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`,
+            :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`)
+            Note that padding size should be less than the corresponding input dimension.
+
+    Shape:
+        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
+        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})` where
+
+          :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`
+
+          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
+
+    Examples::
+
+        >>> # xdoctest: +IGNORE_WANT("not sure why xdoctest is choking on this")
+        >>> m = nn.ReflectionPad2d(2)
+        >>> input = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)
+        >>> input
+        tensor([[[[0., 1., 2.],
+                  [3., 4., 5.],
+                  [6., 7., 8.]]]])
+        >>> m(input)
+        tensor([[[[8., 7., 6., 7., 8., 7., 6.],
+                  [5., 4., 3., 4., 5., 4., 3.],
+                  [2., 1., 0., 1., 2., 1., 0.],
+                  [5., 4., 3., 4., 5., 4., 3.],
+                  [8., 7., 6., 7., 8., 7., 6.],
+                  [5., 4., 3., 4., 5., 4., 3.],
+                  [2., 1., 0., 1., 2., 1., 0.]]]])
+        >>> # using different paddings for different sides
+        >>> m = nn.ReflectionPad2d((1, 1, 2, 0))
+        >>> m(input)
+        tensor([[[[7., 6., 7., 8., 7.],
+                  [4., 3., 4., 5., 4.],
+                  [1., 0., 1., 2., 1.],
+                  [4., 3., 4., 5., 4.],
+                  [7., 6., 7., 8., 7.]]]])
+    """
+
+    padding: tuple[int, int, int, int]
+
+    def __init__(self, padding: _size_4_t) -> None:
+        super().__init__()
+        self.padding = _quadruple(padding)
+
+
+class ReflectionPad3d(_ReflectionPadNd):
+    r"""Pads the input tensor using the reflection of the input boundary.
+
+    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
+
+    Args:
+        padding (int, tuple): the size of the padding. If is `int`, uses the same
+            padding in all boundaries. If a 6-`tuple`, uses
+            (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
+            :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`,
+            :math:`\text{padding\_front}`, :math:`\text{padding\_back}`)
+            Note that padding size should be less than the corresponding input dimension.
+
+    Shape:
+        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
+        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`,
+          where
+
+          :math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}`
+
+          :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`
+
+          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
+
+    Examples::
+
+        >>> # xdoctest: +IGNORE_WANT("not sure why xdoctest is choking on this")
+        >>> m = nn.ReflectionPad3d(1)
+        >>> input = torch.arange(8, dtype=torch.float).reshape(1, 1, 2, 2, 2)
+        >>> m(input)
+        tensor([[[[[7., 6., 7., 6.],
+                   [5., 4., 5., 4.],
+                   [7., 6., 7., 6.],
+                   [5., 4., 5., 4.]],
+                  [[3., 2., 3., 2.],
+                   [1., 0., 1., 0.],
+                   [3., 2., 3., 2.],
+                   [1., 0., 1., 0.]],
+                  [[7., 6., 7., 6.],
+                   [5., 4., 5., 4.],
+                   [7., 6., 7., 6.],
+                   [5., 4., 5., 4.]],
+                  [[3., 2., 3., 2.],
+                   [1., 0., 1., 0.],
+                   [3., 2., 3., 2.],
+                   [1., 0., 1., 0.]]]]])
+    """
+
+    padding: tuple[int, int, int, int, int, int]
+
+    def __init__(self, padding: _size_6_t) -> None:
+        super().__init__()
+        self.padding = _ntuple(6)(padding)
+
+
+class _ReplicationPadNd(Module):
+    __constants__ = ["padding"]
+    padding: Sequence[int]
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.pad(input, self.padding, "replicate")
+
+    def extra_repr(self) -> str:
+        return f"{self.padding}"
+
+
+class ReplicationPad1d(_ReplicationPadNd):
+    r"""Pads the input tensor using replication of the input boundary.
+
+    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
+
+    Args:
+        padding (int, tuple): the size of the padding. If is `int`, uses the same
+            padding in all boundaries. If a 2-`tuple`, uses
+            (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)
+            Note that the output dimensions must remain positive.
+
+    Shape:
+        - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.
+        - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where
+
+          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
+
+    Examples::
+
+        >>> # xdoctest: +IGNORE_WANT("not sure why xdoctest is choking on this")
+        >>> m = nn.ReplicationPad1d(2)
+        >>> input = torch.arange(8, dtype=torch.float).reshape(1, 2, 4)
+        >>> input
+        tensor([[[0., 1., 2., 3.],
+                 [4., 5., 6., 7.]]])
+        >>> m(input)
+        tensor([[[0., 0., 0., 1., 2., 3., 3., 3.],
+                 [4., 4., 4., 5., 6., 7., 7., 7.]]])
+        >>> # using different paddings for different sides
+        >>> m = nn.ReplicationPad1d((3, 1))
+        >>> m(input)
+        tensor([[[0., 0., 0., 0., 1., 2., 3., 3.],
+                 [4., 4., 4., 4., 5., 6., 7., 7.]]])
+    """
+
+    padding: tuple[int, int]
+
+    def __init__(self, padding: _size_2_t) -> None:
+        super().__init__()
+        self.padding = _pair(padding)
+
+
+class ReplicationPad2d(_ReplicationPadNd):
+    r"""Pads the input tensor using replication of the input boundary.
+
+    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
+
+    Args:
+        padding (int, tuple): the size of the padding. If is `int`, uses the same
+            padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`,
+            :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`)
+            Note that the output dimensions must remain positive.
+
+    Shape:
+        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
+        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where
+
+          :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`
+
+          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
+
+    Examples::
+
+        >>> m = nn.ReplicationPad2d(2)
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> input = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)
+        >>> input
+        tensor([[[[0., 1., 2.],
+                  [3., 4., 5.],
+                  [6., 7., 8.]]]])
+        >>> m(input)
+        tensor([[[[0., 0., 0., 1., 2., 2., 2.],
+                  [0., 0., 0., 1., 2., 2., 2.],
+                  [0., 0., 0., 1., 2., 2., 2.],
+                  [3., 3., 3., 4., 5., 5., 5.],
+                  [6., 6., 6., 7., 8., 8., 8.],
+                  [6., 6., 6., 7., 8., 8., 8.],
+                  [6., 6., 6., 7., 8., 8., 8.]]]])
+        >>> # using different paddings for different sides
+        >>> m = nn.ReplicationPad2d((1, 1, 2, 0))
+        >>> m(input)
+        tensor([[[[0., 0., 1., 2., 2.],
+                  [0., 0., 1., 2., 2.],
+                  [0., 0., 1., 2., 2.],
+                  [3., 3., 4., 5., 5.],
+                  [6., 6., 7., 8., 8.]]]])
+    """
+
+    padding: tuple[int, int, int, int]
+
+    def __init__(self, padding: _size_4_t) -> None:
+        super().__init__()
+        self.padding = _quadruple(padding)
+
+
+class ReplicationPad3d(_ReplicationPadNd):
+    r"""Pads the input tensor using replication of the input boundary.
+
+    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
+
+    Args:
+        padding (int, tuple): the size of the padding. If is `int`, uses the same
+            padding in all boundaries. If a 6-`tuple`, uses
+            (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
+            :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`,
+            :math:`\text{padding\_front}`, :math:`\text{padding\_back}`)
+            Note that the output dimensions must remain positive.
+
+    Shape:
+        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
+        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`,
+          where
+
+          :math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}`
+
+          :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`
+
+          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
+
+    Examples::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = nn.ReplicationPad3d(3)
+        >>> input = torch.randn(16, 3, 8, 320, 480)
+        >>> output = m(input)
+        >>> # using different paddings for different sides
+        >>> m = nn.ReplicationPad3d((3, 3, 6, 6, 1, 1))
+        >>> output = m(input)
+    """
+
+    padding: tuple[int, int, int, int, int, int]
+
+    def __init__(self, padding: _size_6_t) -> None:
+        super().__init__()
+        self.padding = _ntuple(6)(padding)
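
The docstrings above keep the upstream torch-style examples. An equivalent exercise against this port, hedged on core.arange, core.float32, and reshape behaving like their torch counterparts (none of which appear in this diff), might look like:

    from mindnlp import core
    from mindnlp.core import nn

    # reflect each row of a (1, 2, 4) input by 2 on both sides -> (1, 2, 8)
    x = core.arange(8, dtype=core.float32).reshape(1, 2, 4)
    y = nn.ReflectionPad1d(2)(x)

    # replicate the border of a 3x3 map with (left, right, top, bottom) = (1, 1, 2, 0),
    # so the spatial size goes from (3, 3) to (5, 5)
    z = nn.ReplicationPad2d((1, 1, 2, 0))(core.arange(9, dtype=core.float32).reshape(1, 1, 3, 3))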
