Skip to content

Commit 87b3452

Browse files
AsakusaRinneOceania2018
authored andcommitted
fix: error when using graph in multi-threads.
1 parent 854e3d7 commit 87b3452

File tree

2 files changed

+44
-2
lines changed

2 files changed

+44
-2
lines changed

src/TensorFlowNET.Core/Device/DeviceSpec.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Collections.Concurrent;
23
using System.Collections.Generic;
34
using System.Text;
45
using System.Threading.Tasks;
@@ -7,8 +8,8 @@ namespace Tensorflow.Device
78
{
89
public class DeviceSpec
910
{
10-
private static Dictionary<string, Components> _STRING_TO_COMPONENTS_CACHE = new();
11-
private static Dictionary<Components, string> _COMPONENTS_TO_STRING_CACHE = new();
11+
private static ConcurrentDictionary<string, Components> _STRING_TO_COMPONENTS_CACHE = new();
12+
private static ConcurrentDictionary<Components, string> _COMPONENTS_TO_STRING_CACHE = new();
1213
private string _job;
1314
private int _replica;
1415
private int _task;
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.IO;
5+
using System.Linq;
6+
using System.Text;
7+
using System.Threading;
8+
using System.Threading.Tasks;
9+
using Tensorflow;
10+
using static Tensorflow.Binding;
11+
12+
namespace TensorFlowNET.UnitTest.Basics
13+
{
14+
[TestClass]
15+
public class ThreadSafeTest
16+
{
17+
[TestMethod]
18+
public void GraphWithMultiThreads()
19+
{
20+
List<Thread> threads = new List<Thread>();
21+
22+
const int THREADS_COUNT = 5;
23+
24+
for (int t = 0; t < THREADS_COUNT; t++)
25+
{
26+
Thread thread = new Thread(() =>
27+
{
28+
Graph g = new Graph();
29+
Session session = new Session(g);
30+
session.as_default();
31+
var input = tf.placeholder(tf.int32, shape: new Shape(6));
32+
var op = tf.reshape(input, new int[] { 2, 3 });
33+
});
34+
thread.Start();
35+
threads.Add(thread);
36+
}
37+
38+
threads.ForEach(t => t.Join());
39+
}
40+
}
41+
}

0 commit comments

Comments
 (0)