Skip to content

Commit e1e3395

Browse files
authored
Adding a public API for metadata. (#618)
* Adding a public API for metadata. * Validate doesn't need to be public.
1 parent bca53e3 commit e1e3395

File tree

1 file changed

+24
-11
lines changed

1 file changed

+24
-11
lines changed

safetensors/src/tensor.rs

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,9 @@ impl<'data> SafeTensors<'data> {
314314
// if !string.starts_with('{') {
315315
// return Err(SafeTensorError::InvalidHeaderStart);
316316
// }
317-
let metadata: Metadata = serde_json::from_str(string)
317+
let metadata: HashMetadata = serde_json::from_str(string)
318318
.map_err(|_| SafeTensorError::InvalidHeaderDeserialization)?;
319+
let metadata: Metadata = metadata.try_into()?;
319320
let buffer_end = metadata.validate()?;
320321
if buffer_end + 8 + n != buffer_len {
321322
return Err(SafeTensorError::MetadataIncompleteBuffer);
@@ -442,20 +443,29 @@ struct HashMetadata {
442443
tensors: HashMap<String, TensorInfo>,
443444
}
444445

445-
impl<'de> Deserialize<'de> for Metadata {
446-
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
447-
where
448-
D: Deserializer<'de>,
449-
{
450-
let hashdata: HashMetadata = HashMetadata::deserialize(deserializer)?;
446+
impl TryFrom<HashMetadata> for Metadata {
447+
type Error = SafeTensorError;
448+
fn try_from(hashdata: HashMetadata) -> Result<Self, Self::Error> {
451449
let (metadata, tensors) = (hashdata.metadata, hashdata.tensors);
452450
let mut tensors: Vec<_> = tensors.into_iter().collect();
453451
// We need to sort by offsets
454452
// Previous versions might have a different ordering
455453
// Than we expect (Not aligned ordered, but purely name ordered,
456454
// or actually any order).
457455
tensors.sort_by(|(_, left), (_, right)| left.data_offsets.cmp(&right.data_offsets));
458-
Metadata::new(metadata, tensors).map_err(serde::de::Error::custom)
456+
Metadata::new(metadata, tensors)
457+
}
458+
}
459+
460+
impl<'de> Deserialize<'de> for Metadata {
461+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
462+
where
463+
D: Deserializer<'de>,
464+
{
465+
let hashdata: HashMetadata = HashMetadata::deserialize(deserializer)?;
466+
467+
let metadata: Metadata = hashdata.try_into().map_err(serde::de::Error::custom)?;
468+
Ok(metadata)
459469
}
460470
}
461471

@@ -487,7 +497,10 @@ impl Serialize for Metadata {
487497
}
488498

489499
impl Metadata {
490-
fn new(
500+
/// Creates a new metadata structure.
501+
/// May fail if there is incorrect data in the Tensor Info.
502+
/// Notably the tensors need to be ordered by increasing data_offsets.
503+
pub fn new(
491504
metadata: Option<HashMap<String, String>>,
492505
tensors: Vec<(String, TensorInfo)>,
493506
) -> Result<Self, SafeTensorError> {
@@ -507,7 +520,7 @@ impl Metadata {
507520
tensors,
508521
index_map,
509522
};
510-
// metadata.validate()?;
523+
metadata.validate()?;
511524
Ok(metadata)
512525
}
513526

@@ -1249,7 +1262,7 @@ mod tests {
12491262
Err(SafeTensorError::TensorInvalidInfo) => {
12501263
// Yes we have the correct error
12511264
}
1252-
_ => panic!("This should not be able to be deserialized"),
1265+
something => panic!("This should not be able to be deserialized got {something:?}"),
12531266
}
12541267
}
12551268

0 commit comments

Comments
 (0)