@@ -10,7 +10,7 @@ use std::fmt;
1010use std:: cmp;
1111use std:: collections:: { HashMap , HashSet } ;
1212use std:: fs:: File ;
13- use std:: io;
13+ use std:: io:: { self , BufReader } ;
1414use std:: path:: Path ;
1515use std:: rc:: Rc ;
1616use std:: sync:: { Arc , RwLock } ;
@@ -455,7 +455,7 @@ impl<B: IBackend> Layer<B> {
455455 // reshape input tensor to the reshaped shape
456456 let old_shape = self . input_blobs_data [ input_i] . read ( ) . unwrap ( ) . desc ( ) . clone ( ) ;
457457 if old_shape. size ( ) != reshaped_shape. size ( ) {
458- panic ! ( "The provided input does not have the expected shape" ) ;
458+ panic ! ( "The provided input does not have the expected shape of {:?}" , reshaped_shape ) ;
459459 }
460460 self . input_blobs_data [ input_i] . write ( ) . unwrap ( ) . reshape ( & reshaped_shape) . unwrap ( ) ;
461461 }
@@ -583,6 +583,39 @@ impl<B: IBackend> Layer<B> {
583583 /// Serialize the Layer and it's weights to a Cap'n Proto file at the specified path.
584584 ///
585585 /// You can find the capnp schema [here](../../../../capnp/leaf.capnp).
586+ ///
587+ /// ```
588+ /// # #[cfg(feature = "native")]
589+ /// # mod native {
590+ /// # use std::rc::Rc;
591+ /// # use leaf::layer::*;
592+ /// # use leaf::layers::*;
593+ /// # use leaf::util;
594+ /// # pub fn test() {
595+ /// #
596+ /// let mut net_cfg = SequentialConfig::default();
597+ /// // ... set up network ...
598+ /// let cfg = LayerConfig::new("network", net_cfg);
599+ ///
600+ /// let native_backend = Rc::new(util::native_backend());
601+ /// let mut layer = Layer::from_config(native_backend, &cfg);
602+ /// // ... do stuff with the layer ...
603+ /// // ... and save it
604+ /// layer.save("mynetwork").unwrap();
605+ /// #
606+ /// # }}
607+ /// #
608+ /// # #[cfg(not(feature = "native"))]
609+ /// # mod native {
610+ /// # pub fn test() {}
611+ /// # }
612+ /// #
613+ /// # fn main() {
614+ /// # if cfg!(feature = "native") {
615+ /// # ::native::test();
616+ /// # }
617+ /// # }
618+ /// ```
586619 pub fn save < P : AsRef < Path > > ( & mut self , path : P ) -> io:: Result < ( ) > {
587620 let path = path. as_ref ( ) ;
588621 let ref mut out = try!( File :: create ( path) ) ;
@@ -597,6 +630,92 @@ impl<B: IBackend> Layer<B> {
597630 Ok ( ( ) )
598631 }
599632
633+ /// Read a Cap'n Proto file at the specified path and deserialize the Layer inside it.
634+ ///
635+ /// You can find the capnp schema [here](../../../../capnp/leaf.capnp).
636+ ///
637+ /// ```
638+ /// # extern crate leaf;
639+ /// # extern crate collenchyma;
640+ /// # #[cfg(feature = "native")]
641+ /// # mod native {
642+ /// # use std::rc::Rc;
643+ /// # use leaf::layer::*;
644+ /// # use leaf::layers::*;
645+ /// # use leaf::util;
646+ /// use collenchyma::prelude::*;
647+ /// # pub fn test() {
648+ ///
649+ /// let native_backend = Rc::new(util::native_backend());
650+ /// # let mut net_cfg = SequentialConfig::default();
651+ /// # let cfg = LayerConfig::new("network", net_cfg);
652+ /// # let mut layer = Layer::from_config(native_backend.clone(), &cfg);
653+ /// # layer.save("mynetwork").unwrap();
654+ /// // Load layer from file "mynetwork"
655+ /// let layer = Layer::<Backend<Native>>::load(native_backend, "mynetwork").unwrap();
656+ /// #
657+ /// # }}
658+ /// #
659+ /// # #[cfg(not(feature = "native"))]
660+ /// # mod native {
661+ /// # pub fn test() {}
662+ /// # }
663+ /// #
664+ /// # fn main() {
665+ /// # if cfg!(feature = "native") {
666+ /// # ::native::test();
667+ /// # }
668+ /// # }
669+ /// ```
670+ pub fn load < LB : IBackend + LayerOps < f32 > + ' static , P : AsRef < Path > > ( backend : Rc < LB > , path : P ) -> io:: Result < Layer < LB > > {
671+ let path = path. as_ref ( ) ;
672+ let ref mut file = try!( File :: open ( path) ) ;
673+ let mut reader = BufReader :: new ( file) ;
674+
675+ let message_reader = :: capnp:: serialize_packed:: read_message ( & mut reader,
676+ :: capnp:: message:: ReaderOptions :: new ( ) ) . unwrap ( ) ;
677+ let read_layer = message_reader. get_root :: < capnp_layer:: Reader > ( ) . unwrap ( ) ;
678+
679+ let name = read_layer. get_name ( ) . unwrap ( ) . to_owned ( ) ;
680+ let layer_config = LayerConfig :: read_capnp ( read_layer. get_config ( ) . unwrap ( ) ) ;
681+ let mut layer = Layer :: from_config ( backend, & layer_config) ;
682+ layer. name = name;
683+
684+ let read_weights = read_layer. get_weights_data ( ) . unwrap ( ) ;
685+
686+ let names = layer. learnable_weights_names ( ) ;
687+ let weights_data = layer. learnable_weights_data ( ) ;
688+
689+ let native_backend = Backend :: < Native > :: default ( ) . unwrap ( ) ;
690+ for ( i, ( name, weight) ) in names. iter ( ) . zip ( weights_data) . enumerate ( ) {
691+ for j in 0 ..read_weights. len ( ) {
692+ let capnp_weight = read_weights. get ( i as u32 ) ;
693+ if capnp_weight. get_name ( ) . unwrap ( ) != name {
694+ continue
695+ }
696+
697+ let mut weight_lock = weight. write ( ) . unwrap ( ) ;
698+ weight_lock. sync ( native_backend. device ( ) ) . unwrap ( ) ;
699+
700+ let capnp_tensor = capnp_weight. get_tensor ( ) . unwrap ( ) ;
701+ let mut shape = Vec :: new ( ) ;
702+ let capnp_shape = capnp_tensor. get_shape ( ) . unwrap ( ) ;
703+ for k in 0 ..capnp_shape. len ( ) {
704+ shape. push ( capnp_shape. get ( k) as usize )
705+ }
706+ weight_lock. reshape ( & shape) . unwrap ( ) ;
707+
708+ let mut native_slice = weight_lock. get_mut ( native_backend. device ( ) ) . unwrap ( ) . as_mut_native ( ) . unwrap ( ) . as_mut_slice :: < f32 > ( ) ;
709+ let data = capnp_tensor. get_data ( ) . unwrap ( ) ;
710+ for k in 0 ..data. len ( ) {
711+ native_slice[ k as usize ] = data. get ( k) ;
712+ }
713+ }
714+ }
715+
716+ Ok ( layer)
717+ }
718+
600719 /// Sets whether the layer should compute gradients w.r.t. a
601720 /// weight at a particular index given by `weight_id`.
602721 ///
@@ -672,6 +791,9 @@ impl<B: IBackend> Layer<B> {
672791 }
673792}
674793
794+ #[ allow( unsafe_code) ]
795+ unsafe impl < B : IBackend > Send for Layer < B > { }
796+
675797impl < ' a , B : IBackend > CapnpWrite < ' a > for Layer < B > {
676798 type Builder = capnp_layer:: Builder < ' a > ;
677799
@@ -1269,6 +1391,31 @@ impl<'a> CapnpWrite<'a> for LayerType {
12691391 }
12701392}
12711393
1394+ impl < ' a > CapnpRead < ' a > for LayerType {
1395+ type Reader = capnp_layer_type:: Reader < ' a > ;
1396+
1397+ fn read_capnp ( reader : Self :: Reader ) -> Self {
1398+ match reader. which ( ) . unwrap ( ) {
1399+ #[ cfg( all( feature="cuda" , not( feature="native" ) ) ) ]
1400+ capnp_layer_type:: Which :: Convolution ( read_config) => { let config = ConvolutionConfig :: read_capnp ( read_config. unwrap ( ) ) ; LayerType :: Convolution ( config) } ,
1401+ #[ cfg( not( all( feature="cuda" , not( feature="native" ) ) ) ) ]
1402+ capnp_layer_type:: Which :: Convolution ( _) => { panic ! ( "Can not load Network because Convolution layer is not supported with the used feature flags." ) } ,
1403+ capnp_layer_type:: Which :: Linear ( read_config) => { let config = LinearConfig :: read_capnp ( read_config. unwrap ( ) ) ; LayerType :: Linear ( config) } ,
1404+ capnp_layer_type:: Which :: LogSoftmax ( read_config) => { LayerType :: LogSoftmax } ,
1405+ #[ cfg( all( feature="cuda" , not( feature="native" ) ) ) ]
1406+ capnp_layer_type:: Which :: Pooling ( read_config) => { let config = PoolingConfig :: read_capnp ( read_config. unwrap ( ) ) ; LayerType :: Pooling ( config) } ,
1407+ #[ cfg( not( all( feature="cuda" , not( feature="native" ) ) ) ) ]
1408+ capnp_layer_type:: Which :: Pooling ( _) => { panic ! ( "Can not load Network because Pooling layer is not supported with the used feature flags." ) } ,
1409+ capnp_layer_type:: Which :: Sequential ( read_config) => { let config = SequentialConfig :: read_capnp ( read_config. unwrap ( ) ) ; LayerType :: Sequential ( config) } ,
1410+ capnp_layer_type:: Which :: Softmax ( _) => { LayerType :: Softmax } ,
1411+ capnp_layer_type:: Which :: Relu ( _) => { LayerType :: ReLU } ,
1412+ capnp_layer_type:: Which :: Sigmoid ( _) => { LayerType :: Sigmoid } ,
1413+ capnp_layer_type:: Which :: NegativeLogLikelihood ( read_config) => { let config = NegativeLogLikelihoodConfig :: read_capnp ( read_config. unwrap ( ) ) ; LayerType :: NegativeLogLikelihood ( config) } ,
1414+ capnp_layer_type:: Which :: Reshape ( read_config) => { let config = ReshapeConfig :: read_capnp ( read_config. unwrap ( ) ) ; LayerType :: Reshape ( config) } ,
1415+ }
1416+ }
1417+ }
1418+
12721419impl LayerConfig {
12731420 /// Creates a new LayerConfig
12741421 pub fn new < L : Into < LayerType > > ( name : & str , layer_type : L ) -> LayerConfig {
@@ -1338,9 +1485,13 @@ impl LayerConfig {
13381485 Err ( "propagate_down config must be specified either 0 or inputs_len times" )
13391486 }
13401487 }
1488+ }
1489+
1490+ impl < ' a > CapnpWrite < ' a > for LayerConfig {
1491+ type Builder = capnp_layer_config:: Builder < ' a > ;
13411492
13421493 /// Write the LayerConfig into a capnp message.
1343- pub fn write_capnp ( & self , builder : & mut capnp_layer_config :: Builder ) {
1494+ fn write_capnp ( & self , builder : & mut Self :: Builder ) {
13441495 builder. set_name ( & self . name ) ;
13451496 {
13461497 let mut layer_type = builder. borrow ( ) . init_layer_type ( ) ;
@@ -1373,3 +1524,44 @@ impl LayerConfig {
13731524 }
13741525 }
13751526}
1527+
1528+ impl < ' a > CapnpRead < ' a > for LayerConfig {
1529+ type Reader = capnp_layer_config:: Reader < ' a > ;
1530+
1531+ fn read_capnp ( reader : Self :: Reader ) -> Self {
1532+ let name = reader. get_name ( ) . unwrap ( ) . to_owned ( ) ;
1533+ let layer_type = LayerType :: read_capnp ( reader. get_layer_type ( ) ) ;
1534+
1535+ let read_outputs = reader. get_outputs ( ) . unwrap ( ) ;
1536+ let mut outputs = Vec :: new ( ) ;
1537+ for i in 0 ..read_outputs. len ( ) {
1538+ outputs. push ( read_outputs. get ( i) . unwrap ( ) . to_owned ( ) )
1539+ }
1540+ let read_inputs = reader. get_inputs ( ) . unwrap ( ) ;
1541+ let mut inputs = Vec :: new ( ) ;
1542+ for i in 0 ..read_inputs. len ( ) {
1543+ inputs. push ( read_inputs. get ( i) . unwrap ( ) . to_owned ( ) )
1544+ }
1545+
1546+ let read_params = reader. get_params ( ) . unwrap ( ) ;
1547+ let mut params = Vec :: new ( ) ;
1548+ for i in 0 ..read_params. len ( ) {
1549+ params. push ( WeightConfig :: read_capnp ( read_params. get ( i) ) )
1550+ }
1551+
1552+ let read_propagate_down = reader. get_propagate_down ( ) . unwrap ( ) ;
1553+ let mut propagate_down = Vec :: new ( ) ;
1554+ for i in 0 ..read_propagate_down. len ( ) {
1555+ propagate_down. push ( read_propagate_down. get ( i) )
1556+ }
1557+
1558+ LayerConfig {
1559+ name : name,
1560+ layer_type : layer_type,
1561+ outputs : outputs,
1562+ inputs : inputs,
1563+ params : params,
1564+ propagate_down : propagate_down,
1565+ }
1566+ }
1567+ }
0 commit comments