Skip to content

Commit e278e31

Browse files
committed
Update demo
1 parent cec2cf2 commit e278e31

File tree

7 files changed

+37
-25
lines changed

7 files changed

+37
-25
lines changed

apps/sgx/Cargo.toml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,3 @@ edition = "2018"
66

77
[dependencies]
88
tvm-runtime = { path = "../../rust/runtime" }
9-
10-
[build-dependencies]
11-
ar = "0.6"

apps/sgx/build.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,25 @@ fn main() {
77
.output()
88
.expect("Failed to execute command");
99
assert!(
10-
std::path::Path::new(&format!("{}/model.o", out_dir)).exists(),
10+
["model.o", "graph.json", "params.bin"].iter().all(|f| {
11+
std::path::Path::new(&format!("{}/{}", out_dir, f)).exists()
12+
}),
1113
"Could not build tvm lib: {}",
1214
String::from_utf8(output.stderr).unwrap().trim()
1315
);
1416

15-
std::process::Command::new("llvm-ar-8")
17+
std::process::Command::new("objcopy")
18+
.arg("--globalize-symbol=__tvm_module_startup")
19+
.arg(&format!("{}/model.o", out_dir))
20+
.output()
21+
.expect("Could not gloablize startup function.");
22+
23+
std::process::Command::new("llvm-ar")
1624
.arg("rcs")
1725
.arg(&format!("{}/libmodel.a", out_dir))
1826
.arg(&format!("{}/model.o", out_dir))
1927
.output()
20-
.expect("Failed to execute command");
28+
.expect("Failed to package model archive.");
2129

2230
println!("cargo:rustc-link-lib=static=model");
2331
println!("cargo:rustc-link-search=native={}", out_dir);

apps/sgx/src/build_model.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,30 +6,28 @@
66
from os import path as osp
77
import sys
88

9-
import nnvm.compiler
10-
import nnvm.testing
119
import tvm
10+
from tvm import relay
11+
import tvm.relay.testing
1212

1313

1414
def main():
15-
# from tutorials/nnvm_quick_start.py
16-
dshape = (1, 3, 224, 224)
17-
net, params = nnvm.testing.resnet.get_workload(
18-
layers=18, batch_size=dshape[0], image_shape=dshape[1:])
15+
dshape = (1, 28, 28)
16+
net, params = relay.testing.mlp.get_workload(batch_size=dshape[0], dtype='float32')
1917

20-
with nnvm.compiler.build_config(opt_level=3):
21-
graph, lib, params = nnvm.compiler.build(
22-
net, 'llvm --system-lib', shape={'data': dshape}, params=params)
18+
with relay.build_config(opt_level=3):
19+
graph, lib, params = relay.build_module.build(
20+
net, target='llvm --system-lib', params=params)
2321

2422
build_dir = osp.abspath(sys.argv[1])
2523
if not osp.isdir(build_dir):
2624
os.makedirs(build_dir, exist_ok=True)
2725

2826
lib.save(osp.join(build_dir, 'model.o'))
2927
with open(osp.join(build_dir, 'graph.json'), 'w') as f_graph_json:
30-
f_graph_json.write(graph.json())
31-
with open(osp.join(build_dir, 'params.bin'), 'wb') as f_params:
32-
f_params.write(nnvm.compiler.save_param_dict(params))
28+
f_graph_json.write(graph)
29+
with open(osp.join(build_dir, 'params.bin'), 'wb') as f_params:
30+
f_params.write(relay.save_param_dict(params))
3331

3432

3533
if __name__ == '__main__':

apps/sgx/src/main.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@ fn main() {
2222
let listener = std::net::TcpListener::bind("127.0.0.1:4242").unwrap();
2323
for stream in listener.incoming() {
2424
let mut stream = stream.unwrap();
25-
stream.read_exact(input_bytes.as_mut_slice()).unwrap();
25+
if let Err(_) = stream.read_exact(input_bytes.as_mut_slice()) {
26+
continue;
27+
}
2628
exec.run();
27-
stream
28-
.write_all(exec.get_output(0).unwrap().data().as_slice())
29-
.unwrap();
29+
if let Err(_) = stream.write_all(exec.get_output(0).unwrap().data().as_slice()) {
30+
continue;
31+
}
3032
}
3133
}

python/tvm/relay/testing/mlp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def get_net(batch_size,
5050
dtype=dtype)
5151
data = relay.nn.batch_flatten(data)
5252
fc1 = relay.nn.dense(data, relay.var("fc1_weight"), units=128)
53-
fc1 = relay.nn.bias_add(fc1, relay.var("fc2_bias"))
53+
fc1 = relay.nn.bias_add(fc1, relay.var("fc1_bias"))
5454
act1 = relay.nn.relu(fc1)
5555
fc2 = relay.nn.dense(act1, relay.var("fc2_weight"), units=64)
5656
fc2 = relay.nn.bias_add(fc2, relay.var("fc2_bias"))

rust/runtime/src/module.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ pub trait Module {
1313

1414
pub struct SystemLibModule;
1515

16+
#[cfg(target_env = "sgx")]
17+
extern "C" {
18+
fn __tvm_module_startup();
19+
}
20+
1621
lazy_static! {
1722
static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, &'static (dyn PackedFunc)>> =
1823
Mutex::new(HashMap::new());
@@ -30,6 +35,8 @@ impl Module for SystemLibModule {
3035

3136
impl Default for SystemLibModule {
3237
fn default() -> Self {
38+
#[cfg(target_env = "sgx")]
39+
unsafe { __tvm_module_startup(); }
3340
SystemLibModule {}
3441
}
3542
}

rust/runtime/src/threading.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ fn max_concurrency() -> usize {
167167
return threads;
168168
}
169169
}
170-
num_cpus::get_physical()
170+
num_cpus::get()
171171
}
172172

173173
#[cfg(target_arch = "wasm32")]
@@ -181,7 +181,7 @@ pub extern "C" fn TVMBackendParallelLaunch(
181181
cdata: *const c_void,
182182
num_task: usize,
183183
) -> c_int {
184-
if max_concurrency() == 0 {
184+
if max_concurrency() < 2 {
185185
let penv = TVMParallelGroupEnv {
186186
sync_handle: 0 as *mut c_void,
187187
num_task: 1,

0 commit comments

Comments
 (0)