|
2 | 2 |
|
3 | 3 | use crate::{Identity, error::RecordError}; |
4 | 4 | use atrium_api::types::string::{Cid, Did, Nsid, RecordKey}; |
5 | | -use reqwest::Client; |
| 5 | +use reqwest::{Client, StatusCode}; |
6 | 6 | use serde::{Deserialize, Serialize}; |
7 | 7 | use serde_json::value::RawValue; |
8 | 8 | use std::str::FromStr; |
@@ -56,6 +56,13 @@ struct RecordResponseObject { |
56 | 56 | value: Box<RawValue>, |
57 | 57 | } |
58 | 58 |
|
| 59 | +#[derive(Debug, Deserialize)] |
| 60 | +pub struct ErrorResponseObject { |
| 61 | + error: String, |
| 62 | + #[allow(dead_code)] |
| 63 | + message: String, |
| 64 | +} |
| 65 | + |
59 | 66 | #[derive(Clone)] |
60 | 67 | pub struct Repo { |
61 | 68 | identity: Identity, |
@@ -87,39 +94,63 @@ impl Repo { |
87 | 94 | return Err(RecordError::NotFound("could not get pds for DID")); |
88 | 95 | }; |
89 | 96 |
|
90 | | - // TODO: throttle by host probably, generally guard against outgoing requests |
| 97 | + // cid gets set to None for a retry, if it's Some and we got NotFound |
| 98 | + let mut cid = cid; |
91 | 99 |
|
92 | | - let mut params = vec![ |
93 | | - ("repo", did.to_string()), |
94 | | - ("collection", collection.to_string()), |
95 | | - ("rkey", rkey.to_string()), |
96 | | - ]; |
97 | | - if let Some(cid) = cid { |
98 | | - params.push(("cid", cid.as_ref().to_string())); |
99 | | - } |
100 | | - let mut url = Url::parse_with_params(&pds, ¶ms)?; |
101 | | - url.set_path("/xrpc/com.atproto.repo.getRecord"); |
| 100 | + let res = loop { |
| 101 | + // TODO: throttle outgoing requests by host probably, generally guard against outgoing requests |
| 102 | + let mut params = vec![ |
| 103 | + ("repo", did.to_string()), |
| 104 | + ("collection", collection.to_string()), |
| 105 | + ("rkey", rkey.to_string()), |
| 106 | + ]; |
| 107 | + if let Some(cid) = cid { |
| 108 | + params.push(("cid", cid.as_ref().to_string())); |
| 109 | + } |
| 110 | + let mut url = Url::parse_with_params(&pds, ¶ms)?; |
| 111 | + url.set_path("/xrpc/com.atproto.repo.getRecord"); |
102 | 112 |
|
103 | | - let res = self |
104 | | - .client |
105 | | - .get(url) |
106 | | - .send() |
107 | | - .await |
108 | | - .map_err(RecordError::SendError)? |
| 113 | + let res = self |
| 114 | + .client |
| 115 | + .get(url.clone()) |
| 116 | + .send() |
| 117 | + .await |
| 118 | + .map_err(RecordError::SendError)?; |
| 119 | + |
| 120 | + if res.status() == StatusCode::BAD_REQUEST { |
| 121 | + // 1. if we're not able to parse json, it's not something we can handle |
| 122 | + let err = res |
| 123 | + .json::<ErrorResponseObject>() |
| 124 | + .await |
| 125 | + .map_err(RecordError::UpstreamBadBadNotGoodRequest)?; |
| 126 | + // 2. if we are, is it a NotFound? and if so, did we try with a CID? |
| 127 | + // if so, retry with no CID (api handler will reject for mismatch but |
| 128 | + // with a nice error + warm cache) |
| 129 | + if err.error == "NotFound" && cid.is_some() { |
| 130 | + cid = &None; |
| 131 | + continue; |
| 132 | + } else { |
| 133 | + return Err(RecordError::UpstreamBadRequest(err)); |
| 134 | + } |
| 135 | + } |
| 136 | + break res; |
| 137 | + }; |
| 138 | + |
| 139 | + let data = res |
109 | 140 | .error_for_status() |
110 | 141 | .map_err(RecordError::StatusError)? // TODO atproto error handling (think about handling not found) |
111 | 142 | .json::<RecordResponseObject>() |
112 | 143 | .await |
113 | 144 | .map_err(RecordError::ParseJsonError)?; // todo... |
114 | 145 |
|
115 | | - let Some(cid) = res.cid else { |
| 146 | + let Some(cid) = data.cid else { |
116 | 147 | return Err(RecordError::MissingUpstreamCid); |
117 | 148 | }; |
118 | 149 | let cid = Cid::from_str(&cid).map_err(|e| RecordError::BadUpstreamCid(e.to_string()))?; |
119 | 150 |
|
120 | 151 | Ok(CachedRecord::Found(RawRecord { |
121 | 152 | cid, |
122 | | - record: res.value.to_string(), |
| 153 | + record: data.value.to_string(), |
123 | 154 | })) |
124 | 155 | } |
125 | 156 | } |
0 commit comments