@@ -314,8 +314,9 @@ impl<'data> SafeTensors<'data> {
314314 // if !string.starts_with('{') {
315315 // return Err(SafeTensorError::InvalidHeaderStart);
316316 // }
317- let metadata: Metadata = serde_json:: from_str ( string)
317+ let metadata: HashMetadata = serde_json:: from_str ( string)
318318 . map_err ( |_| SafeTensorError :: InvalidHeaderDeserialization ) ?;
319+ let metadata: Metadata = metadata. try_into ( ) ?;
319320 let buffer_end = metadata. validate ( ) ?;
320321 if buffer_end + 8 + n != buffer_len {
321322 return Err ( SafeTensorError :: MetadataIncompleteBuffer ) ;
@@ -442,20 +443,29 @@ struct HashMetadata {
442443 tensors : HashMap < String , TensorInfo > ,
443444}
444445
445- impl < ' de > Deserialize < ' de > for Metadata {
446- fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
447- where
448- D : Deserializer < ' de > ,
449- {
450- let hashdata: HashMetadata = HashMetadata :: deserialize ( deserializer) ?;
446+ impl TryFrom < HashMetadata > for Metadata {
447+ type Error = SafeTensorError ;
448+ fn try_from ( hashdata : HashMetadata ) -> Result < Self , Self :: Error > {
451449 let ( metadata, tensors) = ( hashdata. metadata , hashdata. tensors ) ;
452450 let mut tensors: Vec < _ > = tensors. into_iter ( ) . collect ( ) ;
453451 // We need to sort by offsets
454452 // Previous versions might have a different ordering
455453 // Than we expect (Not aligned ordered, but purely name ordered,
456454 // or actually any order).
457455 tensors. sort_by ( |( _, left) , ( _, right) | left. data_offsets . cmp ( & right. data_offsets ) ) ;
458- Metadata :: new ( metadata, tensors) . map_err ( serde:: de:: Error :: custom)
456+ Metadata :: new ( metadata, tensors)
457+ }
458+ }
459+
460+ impl < ' de > Deserialize < ' de > for Metadata {
461+ fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
462+ where
463+ D : Deserializer < ' de > ,
464+ {
465+ let hashdata: HashMetadata = HashMetadata :: deserialize ( deserializer) ?;
466+
467+ let metadata: Metadata = hashdata. try_into ( ) . map_err ( serde:: de:: Error :: custom) ?;
468+ Ok ( metadata)
459469 }
460470}
461471
@@ -487,7 +497,10 @@ impl Serialize for Metadata {
487497}
488498
489499impl Metadata {
490- fn new (
500+ /// Creates a new metadata structure.
501+ /// May fail if there is incorrect data in the Tensor Info.
502+ /// Notably the tensors need to be ordered by increasing data_offsets.
503+ pub fn new (
491504 metadata : Option < HashMap < String , String > > ,
492505 tensors : Vec < ( String , TensorInfo ) > ,
493506 ) -> Result < Self , SafeTensorError > {
@@ -507,7 +520,7 @@ impl Metadata {
507520 tensors,
508521 index_map,
509522 } ;
510- // metadata.validate()?;
523+ metadata. validate ( ) ?;
511524 Ok ( metadata)
512525 }
513526
@@ -1249,7 +1262,7 @@ mod tests {
12491262 Err ( SafeTensorError :: TensorInvalidInfo ) => {
12501263 // Yes we have the correct error
12511264 }
1252- _ => panic ! ( "This should not be able to be deserialized" ) ,
1265+ something => panic ! ( "This should not be able to be deserialized got {something:?} " ) ,
12531266 }
12541267 }
12551268
0 commit comments