Skip to content

Commit 1a9e610

Browse files
committed
accessed by index #132
1 parent 8f42762 commit 1a9e610

File tree

6 files changed

+51
-24
lines changed

6 files changed

+51
-24
lines changed
Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,23 @@
11
using System;
2+
using System.Collections;
23
using System.Collections.Generic;
34
using System.Text;
45

56
namespace Tensorflow
67
{
7-
public class InputList
8+
public class InputList : IEnumerable
89
{
910
public Tensor[] _inputs;
11+
public Tensor this[int index] => _inputs[index];
1012

1113
public InputList(Tensor[] inputs)
1214
{
1315
_inputs = inputs;
1416
}
17+
18+
public IEnumerator GetEnumerator()
19+
{
20+
return _inputs.GetEnumerator();
21+
}
1522
}
1623
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public partial class Operation
8+
{
9+
/// <summary>
10+
/// Add this op to its control flow context.
11+
/// </summary>
12+
public void _control_flow_post_processing()
13+
{
14+
foreach(var input_tensor in inputs)
15+
{
16+
17+
}
18+
}
19+
}
20+
}

src/TensorFlowNET.Core/Operations/Operation.Output.cs

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,7 @@ public partial class Operation
1212
public int OutputListLength(string name) => c_api.TF_OperationOutputListLength(_handle, name, status);
1313

1414
private Tensor[] _outputs;
15-
public Tensor[] outputs
16-
{
17-
get
18-
{
19-
if (_outputs == null)
20-
{
21-
_outputs = new Tensor[NumOutputs];
22-
23-
for (int i = 0; i < NumOutputs; i++)
24-
_outputs[i] = new Tensor(this, i, OutputType(i));
25-
}
26-
27-
return _outputs;
28-
}
29-
}
15+
public Tensor[] outputs => _outputs;
3016

3117
public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle);
3218
public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index));

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace Tensorflow
88
{
99
public partial class Operation
1010
{
11-
private readonly IntPtr _handle;
11+
private readonly IntPtr _handle; // _c_op in python
1212

1313
public Graph Graph { get; }
1414
public int _id => _id_value;
@@ -97,12 +97,20 @@ public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataT
9797

9898
_handle = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray());
9999

100+
// Initialize self._outputs.
100101
output_types = new TF_DataType[NumOutputs];
101102

102103
for (int i = 0; i < NumOutputs; i++)
103104
output_types[i] = OutputType(i);
104105

106+
_outputs = new Tensor[NumOutputs];
107+
for (int i = 0; i < NumOutputs; i++)
108+
_outputs[i] = new Tensor(this, i, OutputType(i));
109+
105110
Graph._add_op(this);
111+
112+
if (_handle != IntPtr.Zero)
113+
_control_flow_post_processing();
106114
}
107115

108116
public object get_attr<T>(string name)

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ public static Tensor add(Tensor x, Tensor y)
1818

1919
var _op = _op_def_lib._apply_op_helper("Add", keywords: keywords);
2020

21-
return new Tensor(_op, 0, _op.OutputType(0));
21+
return _op.outputs[0];
2222
}
2323

2424
public static Tensor sub(Tensor x, Tensor y)
@@ -29,7 +29,7 @@ public static Tensor sub(Tensor x, Tensor y)
2929

3030
var _op = _op_def_lib._apply_op_helper("Sub", name: "sub", keywords: keywords);
3131

32-
return new Tensor(_op, 0, _op.OutputType(0));
32+
return _op.outputs[0];
3333
}
3434

3535
public static Tensor mul(Tensor x, Tensor y)
@@ -40,7 +40,7 @@ public static Tensor mul(Tensor x, Tensor y)
4040

4141
var _op = _op_def_lib._apply_op_helper("Mul", keywords: keywords);
4242

43-
return new Tensor(_op, 0, _op.OutputType(0));
43+
return _op.outputs[0];
4444
}
4545

4646
public static Tensor real_div(Tensor x, Tensor y)
@@ -51,7 +51,7 @@ public static Tensor real_div(Tensor x, Tensor y)
5151

5252
var _op = _op_def_lib._apply_op_helper("RealDiv", name: "truediv", keywords: keywords);
5353

54-
return new Tensor(_op, 0, _op.OutputType(0));
54+
return _op.outputs[0];
5555
}
5656

5757
public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false)
@@ -64,7 +64,7 @@ public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool
6464

6565
var _op = _op_def_lib._apply_op_helper("MatMul", keywords: keywords);
6666

67-
return new Tensor(_op, 0, _op.OutputType(0));
67+
return _op.outputs[0];
6868
}
6969

7070
public static Tensor pow(Tensor x, double y)
@@ -75,7 +75,7 @@ public static Tensor pow(Tensor x, double y)
7575

7676
var _op = _op_def_lib._apply_op_helper("Pow", keywords: keywords);
7777

78-
return new Tensor(_op, 0, _op.OutputType(0));
78+
return _op.outputs[0];
7979
}
8080

8181
public static Tensor sum(Tensor input, Tensor axis = null)
@@ -87,7 +87,7 @@ public static Tensor sum(Tensor input, Tensor axis = null)
8787

8888
var _op = _op_def_lib._apply_op_helper("Sum", keywords: keywords);
8989

90-
return new Tensor(_op, 0, _op.OutputType(0));
90+
return _op.outputs[0];
9191
}
9292

9393
/// <summary>

src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ public static implicit operator Tensor(RefVariable var)
2323

2424
public static implicit operator RefVariable(Tensor var)
2525
{
26+
switch (var.dtype)
27+
{
28+
case TF_DataType.TF_INT32:
29+
return tf.Variable(var.Data<int>()[0]);
30+
}
31+
2632
return null;
2733
}
2834
}

0 commit comments

Comments
 (0)