Commit 4173f14

fix apis for f class (#2153)
1 parent d316cf3 commit 4173f14

6 files changed, +79 -141 lines changed


mindnlp/core/_prims/meta.py

Lines changed: 40 additions & 136 deletions
@@ -15,143 +15,16 @@ def arange(start, end, step, dtype):
 
 __all__.append('arange')
 
-import math
-from typing import Tuple, Union
-
-def infer_broadcast_shape(input_shape: Tuple[int, ...],
-                          target_shape: Tuple[Union[int, None], ...]) -> Tuple[int, ...]:
-    """
-    Infer the output shape of torch.broadcast_to.
-
-    Args:
-        input_shape: shape tuple of the input tensor (e.g. (3, 1))
-        target_shape: target broadcast shape tuple (may contain None for dimensions to infer)
-
-    Returns:
-        The broadcast output shape tuple.
-
-    Raises:
-        ValueError: if the shapes are not broadcast-compatible.
-    """
-    # Handle None values (automatic dimension inference)
-    final_target_shape = []
-    for i, dim in enumerate(target_shape):
-        if dim is None:
-            # Find the positions of dimensions that could be inferred
-            candidates = [j for j, d in enumerate(target_shape) if d is None]
-            if len(candidates) > 1:
-                raise ValueError(f"Multiple None dimensions {candidates}; cannot infer unambiguously")
-            final_target_shape.append(None)
-        elif dim < -1:
-            raise ValueError(f"Dimension sizes cannot be negative (except -1), got {dim}")
-        else:
-            final_target_shape.append(dim)
-
-    # Count the total number of elements
-    def count_product(shape, exclude_none=True):
-        prod = 1
-        for dim in shape:
-            if dim == 0:
-                return 0  # any dimension of 0 makes the result 0
-            if dim is not None and not (exclude_none and dim == -1):
-                prod *= max(1, dim)  # treat -1 as 1 for counting
-        return prod
-
-    # Validate rank compatibility
-    ndim_input = len(input_shape)
-    ndim_target = len(final_target_shape)
-
-    if ndim_input > ndim_target:
-        raise ValueError(
-            f"Input has more dimensions ({ndim_input}) than the target ({ndim_target}); "
-            f"cannot broadcast: {input_shape} -> {final_target_shape}"
-        )
-
-    # Build the aligned shape (left-pad with 1s)
-    aligned_input_shape = (1,) * (ndim_target - ndim_input) + input_shape
-    inferred_target_shape = list(final_target_shape)
-    known_product = 1
-
-    # First pass: collect the known information
-    for i in range(ndim_target):
-        target_dim = inferred_target_shape[i]
-        input_dim = aligned_input_shape[i]
-
-        if target_dim == -1:
-            # Mark dimensions that need to be inferred
-            inferred_target_shape[i] = None
-        elif target_dim is not None:
-            # Validate dimension compatibility
-            if target_dim == 0:
-                if input_dim not in (0, 1):
-                    raise ValueError(
-                        f"Dimension {i}: when the target dimension is 0 the input dimension must be 0 or 1, "
-                        f"got {input_dim} -> {target_dim}"
-                    )
-            else:  # positive dimension
-                if input_dim != 1 and input_dim != target_dim:
-                    raise ValueError(
-                        f"Dimension {i}: size {input_dim} cannot be broadcast to {target_dim}"
-                    )
-            known_product *= target_dim
-
-    # Second pass: infer dimensions
-    total_elements = math.prod([d for d in input_shape if d != 0])
-    inferred_product = known_product
-
-    # Count the dimensions that still need to be inferred
-    none_indices = [i for i, d in enumerate(inferred_target_shape) if d is None]
-    num_infer = len(none_indices)
-
-    if num_infer > 0:
-        # Total number of elements that must be accounted for
-        required_total = total_elements
-
-        # Special case: the input contains a 0-sized dimension
-        if 0 in input_shape:
-            if required_total != 0:
-                raise ValueError("Cannot infer non-zero dimensions when broadcasting an input with a 0-sized dimension")
-            # Every inferred dimension must be 0
-            for i in none_indices:
-                inferred_target_shape[i] = 0
-        else:
-            if inferred_product == 0 and required_total > 0:
-                raise ValueError(
-                    "Cannot broadcast a non-empty input to a target shape with a 0-sized dimension: "
-                    f"{input_shape} -> {inferred_target_shape}"
-                )
-
-            # Product of the dimensions to be inferred
-            infer_product = required_total // inferred_product if inferred_product != 0 else 0
-
-            if infer_product * inferred_product != required_total:
-                raise ValueError(
-                    f"Incompatible element counts: the input has {total_elements} elements, "
-                    f"but the target shape can only hold {inferred_product * infer_product} elements"
-                )
-
-            # Check whether an integer split is possible
-            for i in none_indices:
-                # Inference only works with a single -1
-                if num_infer == 1:
-                    inferred_target_shape[i] = infer_product
-                else:
-                    # Multiple dimensions cannot be inferred automatically
-                    raise ValueError(
-                        f"{len(none_indices)} dimensions need to be inferred: {none_indices} "
-                        "but there are not enough constraints"
-                    )
-
-    # Convert to a concrete shape tuple
-    result_shape = tuple(
-        d if d is not None else -1  # keep -1 to mean "unspecified"
-        for d in inferred_target_shape
-    )
-
-    return result_shape
-
 def broadcast_to(input, shape):
-    out_shape = infer_broadcast_shape(input.shape, shape)
+    out_shape = ()
+    input_shape = input.shape
+    if len(input_shape) != shape:
+        input_shape = (1,) + input_shape
+    for idx, s in enumerate(shape):
+        if s == -1:
+            s = input_shape[idx]
+        out_shape += (s,)
+
     out = Tensor_(shape=out_shape, dtype=input.dtype)
     return core.Tensor(out)
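The rewritten broadcast_to only has to produce output metadata for the meta backend, so it resolves each target dimension directly: a -1 entry falls back to the matching input dimension. A standalone sketch of that resolution, assuming the rank check is meant to compare len(input_shape) with len(shape) and that the input rank is at most one below the target rank:

def resolve_broadcast_shape(input_shape, shape):
    # Left-pad the input with a leading 1 when its rank falls short of the target.
    if len(input_shape) != len(shape):
        input_shape = (1,) + input_shape
    out_shape = ()
    for idx, s in enumerate(shape):
        if s == -1:
            # A -1 in the target shape keeps the corresponding input dimension.
            s = input_shape[idx]
        out_shape += (s,)
    return out_shape

assert resolve_broadcast_shape((3, 1), (-1, 4)) == (3, 4)
assert resolve_broadcast_shape((4,), (2, -1)) == (2, 4)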

@@ -437,3 +310,34 @@ def squeeze(input, dim):
     return core.Tensor(out)
 
 __all__.append('squeeze')
+
+def exp(input):
+    return input
+
+__all__.append('exp')
+
+def rand_ext(size, seed, offset, dtype):
+    out = Tensor_(shape=size, dtype=dtype)
+    return core.Tensor(out)
+
+__all__.append('rand_ext')
+
+def add(input, other):
+    return input
+
+__all__.append('add')
+
+def neg(input):
+    return input
+
+__all__.append('neg')
+
+def expm1(input):
+    return input
+
+__all__.append('expm1')
+
+def reverse_v2(input, dims):
+    return input
+
+__all__.append('reverse_v2')
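The new meta stubs for exp, add, neg, expm1, and reverse_v2 return the input unchanged because an element-wise op preserves shape and dtype, which is all the meta backend tracks; only rand_ext allocates a fresh Tensor_ since its shape comes from size. A rough, purely illustrative sketch of that shape-only idea (MetaTensor and meta_exp are made-up names, not mindnlp APIs):

from dataclasses import dataclass

@dataclass
class MetaTensor:
    shape: tuple
    dtype: str

def meta_exp(x: MetaTensor) -> MetaTensor:
    # Element-wise op: output metadata equals input metadata, no values computed.
    return MetaTensor(x.shape, x.dtype)

print(meta_exp(MetaTensor((2, 3), "float32")))  # MetaTensor(shape=(2, 3), dtype='float32')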

mindnlp/core/_prims/numpy.py

Lines changed: 24 additions & 0 deletions
@@ -785,3 +785,27 @@ def linalg_vector_norm(input, p, dim, keepdim, dtype):
     return core.Tensor.from_numpy(out)
 
 __all__.append('linalg_vector_norm')
+
+def exp(input):
+    out = np.exp(input.numpy())
+    return core.Tensor.from_numpy(out)
+
+__all__.append('exp')
+
+def expm1(input):
+    out = np.expm1(input.numpy())
+    return core.Tensor.from_numpy(out)
+
+__all__.append('expm1')
+
+def ones_like(input):
+    out = np.ones_like(input.numpy())
+    return core.Tensor.from_numpy(out)
+
+__all__.append('ones_like')
+
+def reverse_v2(input, dims):
+    out = np.flip(input.numpy(), dims)
+    return core.Tensor.from_numpy(out)
+
+__all__.append('reverse_v2')
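The numpy backend versions round-trip through NumPy: input.numpy() exposes the data, the ufunc does the work, and core.Tensor.from_numpy wraps the result; reverse_v2 maps onto np.flip over the given axes. A quick NumPy-only check of the two less obvious ops:

import numpy as np

a = np.arange(6).reshape(2, 3)
print(np.flip(a, 1))        # [[2 1 0]
                            #  [5 4 3]] -- axis 1 reversed, like reverse_v2(input, dims)
print(np.expm1(np.array([0.0, 1e-8])))  # close to [0.0, 1e-08]; expm1 stays accurate for tiny x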

mindnlp/core/_tensor.py

Lines changed: 1 addition & 1 deletion
@@ -381,7 +381,7 @@ def __round__(self):
 
     def new(self, *shape):
         if not isinstance(shape[0], int):
-            return tensor(shape[0], dtype=self.dtype)
+            return tensor(shape[0], dtype=self.dtype, device=self.device)
         return ops.empty(*shape, dtype=self.dtype, device=self.device)
 
     # Tensor.new_tensor
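Tensor.new previously dropped the device when called with data rather than integer sizes; the data branch now forwards device=self.device so it matches the ops.empty branch. A minimal sketch of the behavior with a stand-in class (FakeTensor is hypothetical, not the mindnlp Tensor):

class FakeTensor:
    # Simplified stand-in, only to illustrate the device propagation.
    def __init__(self, data=None, shape=None, dtype="float32", device="cpu"):
        self.data, self.shape, self.dtype, self.device = data, shape, dtype, device

    def new(self, *shape):
        if not isinstance(shape[0], int):
            # Data branch: now keeps both dtype and device of the source tensor.
            return FakeTensor(data=shape[0], dtype=self.dtype, device=self.device)
        return FakeTensor(shape=shape, dtype=self.dtype, device=self.device)

x = FakeTensor(data=[1.0, 2.0], device="npu:0")
assert x.new([3.0, 4.0]).device == "npu:0"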

mindnlp/core/fft/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ def irfft(input, n=None, dim=-1, norm="backward"):
     # return _irfft(input)
 
 def fftn(input, s=None, dim=None, norm=None):
-    return ops.fftn(input, s, dim, norm)
+    return execute('fftn', input, s, dim, norm)
 
 def fft(input, s=None, dim=-1, norm=None):
     return ops.fft(input, s, dim, norm)
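fftn now routes through the same execute(name, ...) dispatch used elsewhere in this commit (execute('zeros', ...), execute('repeat_elements', ...)) instead of calling ops.fftn directly. A toy registry-style dispatcher in that spirit, purely illustrative and not mindnlp's actual execute:

_OPS = {}

def register_op(name):
    def deco(fn):
        _OPS[name] = fn
        return fn
    return deco

def execute(name, *args, **kwargs):
    # Look up the implementation registered under `name` and call it.
    return _OPS[name](*args, **kwargs)

@register_op("fftn")
def _fftn(x, s=None, dim=None, norm=None):
    import numpy as np
    return np.fft.fftn(x, s=s, axes=dim, norm=norm)

print(execute("fftn", [[1.0, 0.0], [0.0, 1.0]]))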

mindnlp/core/ops/creation.py

Lines changed: 7 additions & 1 deletion
@@ -42,7 +42,13 @@ def zeros(*size, out=None, dtype=None, layout=None, device=None, requires_grad=F
     if isinstance(size[0], (tuple, list)):
         size = size[0]
 
-    output = execute('zeros', size, dtype, device=device, requires_grad=requires_grad, user_created=True)
+    new_size = ()
+    for s in size:
+        if not isinstance(s, int):
+            s = s.item()
+        new_size += (s,)
+
+    output = execute('zeros', new_size, dtype, device=device, requires_grad=requires_grad, user_created=True)
     if out is None:
         return output
     out.data = output
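zeros now normalizes its size entries before dispatch: anything that is not a plain Python int is converted with .item(), which matters when sizes come from NumPy scalars or 0-d tensors produced by shape arithmetic. A small check of the assumption that such non-int scalars are what trigger this path:

import numpy as np

s = np.int64(3)
print(isinstance(s, int))        # False: NumPy scalars are not Python ints
print(s.item(), type(s.item()))  # 3 <class 'int'>

new_size = ()
for s in (2, np.int64(3)):
    if not isinstance(s, int):
        s = s.item()
    new_size += (s,)
print(new_size)                  # (2, 3) -- only plain ints reach execute('zeros', ...)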

mindnlp/core/ops/other.py

Lines changed: 6 additions & 2 deletions
@@ -710,10 +710,14 @@ def repeat_interleave(input, repeats, dim=None, *, output_size=None):
 
     dim = dim + input.ndim if dim < 0 else dim
 
+
+    if sum(repeats) == 0:
+        out_shape = list(input.shape)
+        out_shape[dim] = 0
+        return core.Tensor(shape=tuple(out_shape), dtype=input.dtype)
+
     if len(repeats) == 1:
         repeats = repeats[0]
-        if repeats == 0:
-            return Tensor_(input.dtype, (0,))
     if input.dtype == mindspore.bool_:
         input = input.to(mindspore.int32)
     out = execute('repeat_elements', input, repeats, dim)
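The old code only handled a single zero repeat and returned a rank-1 empty tensor; the new branch triggers whenever sum(repeats) == 0 and preserves every dimension except dim, which becomes 0. The shape arithmetic in isolation:

input_shape = (2, 3, 4)
dim = 1
repeats = [0, 0, 0]

if sum(repeats) == 0:
    out_shape = list(input_shape)
    out_shape[dim] = 0          # only the repeated dimension collapses
    print(tuple(out_shape))     # (2, 0, 4) rather than the old flat (0,)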
