@@ -15,143 +15,16 @@ def arange(start, end, step, dtype):
15
15
16
16
__all__ .append ('arange' )
17
17
18
- import math
19
- from typing import Tuple , Union
20
-
21
- def infer_broadcast_shape (input_shape : Tuple [int , ...],
22
- target_shape : Tuple [Union [int , None ], ...]) -> Tuple [int , ...]:
23
- """
24
- 推断 torch.broadcast_to 的输出形状
25
-
26
- 参数:
27
- input_shape: 输入张量的形状元组 (例如 (3, 1))
28
- target_shape: 目标广播形状元组 (可包含None表示自动推断维度)
29
-
30
- 返回:
31
- 广播后的输出形状元组
32
-
33
- 异常:
34
- ValueError: 当广播不兼容时
35
- """
36
- # 处理 None 值(自动维度推断)
37
- final_target_shape = []
38
- for i , dim in enumerate (target_shape ):
39
- if dim is None :
40
- # 查找可以推断的维度位置
41
- candidates = [j for j , d in enumerate (target_shape ) if d is None ]
42
- if len (candidates ) > 1 :
43
- raise ValueError (f"多个None维度 { candidates } ,无法明确推断" )
44
- final_target_shape .append (None )
45
- elif dim < - 1 :
46
- raise ValueError (f"维度大小不能为负数 (除-1外),发现 { dim } " )
47
- else :
48
- final_target_shape .append (dim )
49
-
50
- # 计算需要推断的总元素数量
51
- def count_product (shape , exclude_none = True ):
52
- prod = 1
53
- for dim in shape :
54
- if dim == 0 :
55
- return 0 # 任何维度为0结果即为0
56
- if dim is not None and not (exclude_none and dim == - 1 ):
57
- prod *= max (1 , dim ) # -1视为1用于计数
58
- return prod
59
-
60
- # 验证维度数量兼容性
61
- ndim_input = len (input_shape )
62
- ndim_target = len (final_target_shape )
63
-
64
- if ndim_input > ndim_target :
65
- raise ValueError (
66
- f"输入维度({ ndim_input } )多于目标维度({ ndim_target } ),"
67
- f"无法广播: { input_shape } -> { final_target_shape } "
68
- )
69
-
70
- # 创建对齐后的形状(左侧填充1)
71
- aligned_input_shape = (1 ,) * (ndim_target - ndim_input ) + input_shape
72
- inferred_target_shape = list (final_target_shape )
73
- known_product = 1
74
-
75
- # 第一遍:收集已知信息
76
- for i in range (ndim_target ):
77
- target_dim = inferred_target_shape [i ]
78
- input_dim = aligned_input_shape [i ]
79
-
80
- if target_dim == - 1 :
81
- # 标记需要推断的维度
82
- inferred_target_shape [i ] = None
83
- elif target_dim is not None :
84
- # 验证维度兼容性
85
- if target_dim == 0 :
86
- if input_dim not in (0 , 1 ):
87
- raise ValueError (
88
- f"维度 { i } : 目标维度为0时输入维度必须为0或1, "
89
- f"但得到 { input_dim } -> { target_dim } "
90
- )
91
- else : # 正数维度
92
- if input_dim != 1 and input_dim != target_dim :
93
- raise ValueError (
94
- f"维度 { i } : 大小 { input_dim } 无法广播到 { target_dim } "
95
- )
96
- known_product *= target_dim
97
-
98
- # 第二遍:推断维度
99
- total_elements = math .prod ([d for d in input_shape if d != 0 ])
100
- inferred_product = known_product
101
-
102
- # 统计需要推断的维度数量
103
- none_indices = [i for i , d in enumerate (inferred_target_shape ) if d is None ]
104
- num_infer = len (none_indices )
105
-
106
- if num_infer > 0 :
107
- # 计算需要推断的总元素量
108
- required_total = total_elements
109
-
110
- # 当输入有0维时的特殊情况
111
- if 0 in input_shape :
112
- if required_total != 0 :
113
- raise ValueError ("含0维输入广播时无法推断非0维度" )
114
- # 所有推断维度必须为0
115
- for i in none_indices :
116
- inferred_target_shape [i ] = 0
117
- else :
118
- if inferred_product == 0 and required_total > 0 :
119
- raise ValueError (
120
- "无法将非0输入广播到含0维的目标形状: "
121
- f"{ input_shape } -> { inferred_target_shape } "
122
- )
123
-
124
- # 计算推断维度的乘积
125
- infer_product = required_total // inferred_product if inferred_product != 0 else 0
126
-
127
- if infer_product * inferred_product != required_total :
128
- raise ValueError (
129
- f"元素总数不兼容: 输入有 { total_elements } 元素, "
130
- f"但目标形状仅能容纳 { inferred_product * infer_product } 元素"
131
- )
132
-
133
- # 检查是否可以整数划分
134
- for i in none_indices :
135
- # 仅当有1个-1时可以推断
136
- if num_infer == 1 :
137
- inferred_target_shape [i ] = infer_product
138
- else :
139
- # 多维度无法自动推断
140
- raise ValueError (
141
- f"多个维度({ len (none_indices )} )需要推断: { none_indices } "
142
- "但未指定足够约束条件"
143
- )
144
-
145
- # 转换为确定形状元组
146
- result_shape = tuple (
147
- d if d is not None else - 1 # 保留-1表示未指定
148
- for d in inferred_target_shape
149
- )
150
-
151
- return result_shape
152
-
153
18
def broadcast_to (input , shape ):
154
- out_shape = infer_broadcast_shape (input .shape , shape )
19
+ out_shape = ()
20
+ input_shape = input .shape
21
+ if len (input_shape ) != shape :
22
+ input_shape = (1 ,) + input_shape
23
+ for idx , s in enumerate (shape ):
24
+ if s == - 1 :
25
+ s = input_shape [idx ]
26
+ out_shape += (s ,)
27
+
155
28
out = Tensor_ (shape = out_shape , dtype = input .dtype )
156
29
return core .Tensor (out )
157
30
@@ -437,3 +310,34 @@ def squeeze(input, dim):
437
310
return core .Tensor (out )
438
311
439
312
__all__ .append ('squeeze' )
313
+
314
+ def exp (input ):
315
+ return input
316
+
317
+ __all__ .append ('exp' )
318
+
319
+ def rand_ext (size , seed , offset , dtype ):
320
+ out = Tensor_ (shape = size , dtype = dtype )
321
+ return core .Tensor (out )
322
+
323
+ __all__ .append ('rand_ext' )
324
+
325
+ def add (input , other ):
326
+ return input
327
+
328
+ __all__ .append ('add' )
329
+
330
+ def neg (input ):
331
+ return input
332
+
333
+ __all__ .append ('neg' )
334
+
335
+ def expm1 (input ):
336
+ return input
337
+
338
+ __all__ .append ('expm1' )
339
+
340
+ def reverse_v2 (input , dims ):
341
+ return input
342
+
343
+ __all__ .append ('reverse_v2' )
0 commit comments