Skip to content

Commit 879ecec

Browse files
committed
fix: wip
1 parent ea8458f commit 879ecec

File tree

3 files changed

+155
-60
lines changed

3 files changed

+155
-60
lines changed
Lines changed: 3 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
1-
use anyhow::Result;
21
use serde::de::{SeqAccess, Visitor};
32
use serde::{Deserialize, Deserializer};
4-
use std::collections::HashMap;
53
use std::fmt::Formatter;
6-
use std::fs::File;
7-
use std::io::Read;
8-
use tracing::{debug, info, info_span};
94

105
#[derive(Debug, Copy, Clone)]
116
pub enum ProxyProtocol {
@@ -15,12 +10,12 @@ pub enum ProxyProtocol {
1510

1611
impl<'de> Deserialize<'de> for ProxyProtocol {
1712
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
18-
where
19-
D: Deserializer<'de>,
13+
where
14+
D: Deserializer<'de>,
2015
{
2116
match bool::deserialize(deserializer)? {
2217
true => Ok(ProxyProtocol::Enabled),
23-
false => Ok(ProxyProtocol::Disabled)
18+
false => Ok(ProxyProtocol::Disabled),
2419
}
2520
}
2621
}
@@ -77,55 +72,3 @@ impl<'de> Deserialize<'de> for Listener {
7772
deserializer.deserialize_any(ListenerVisitor)
7873
}
7974
}
80-
81-
#[derive(Deserialize, Debug)]
82-
pub struct Api {
83-
pub http: Listener,
84-
pub https: Listener,
85-
pub prom: Listener,
86-
}
87-
88-
fn default_acme() -> String {
89-
"https://acme-v02.api.letsencrypt.org/directory".to_string()
90-
}
91-
92-
#[derive(Deserialize, Debug)]
93-
pub struct General {
94-
pub dns: String,
95-
pub db: String,
96-
#[serde(default = "default_acme")]
97-
pub acme: String,
98-
pub name: String,
99-
}
100-
101-
#[derive(Deserialize, Debug)]
102-
pub struct Config {
103-
pub general: General,
104-
pub api: Api,
105-
#[serde(default)]
106-
pub records: HashMap<String, Vec<Vec<String>>>,
107-
}
108-
109-
const DEFAULT_CONFIG_PATH: &str = "config.toml";
110-
111-
// is not async so we can use it to load settings for tokio runtime
112-
pub fn load_config(config_path: Option<String>) -> Result<Config> {
113-
let config_path = config_path.as_deref().unwrap_or(DEFAULT_CONFIG_PATH);
114-
115-
let span = info_span!("load_config", config_path);
116-
let _enter = span.enter();
117-
118-
let mut file = File::open(config_path)?;
119-
debug!(?file, "Opened file");
120-
121-
let mut bytes = vec![];
122-
file.read_to_end(&mut bytes)?;
123-
debug!(file_length = bytes.len(), "Read file");
124-
125-
let config = toml::de::from_slice::<Config>(&bytes)?;
126-
// redact db information
127-
let config_str = format!("{:?}", config).replace(&config.general.db, "******");
128-
info!(config = %config_str, "Deserialized config");
129-
130-
Ok(config)
131-
}

src/config/mod.rs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
use anyhow::Result;
2+
use serde::Deserialize;
3+
use std::collections::HashMap;
4+
use std::fs::File;
5+
use std::io::Read;
6+
use tracing::{debug, info, info_span};
7+
8+
pub use listener::{Listener, ProxyProtocol};
9+
10+
mod listener;
11+
mod records;
12+
13+
#[derive(Deserialize, Debug)]
14+
pub struct Api {
15+
pub http: Listener,
16+
pub https: Listener,
17+
pub prom: Listener,
18+
}
19+
20+
fn default_acme() -> String {
21+
"https://acme-v02.api.letsencrypt.org/directory".to_string()
22+
}
23+
24+
#[derive(Deserialize, Debug)]
25+
pub struct General {
26+
pub dns: String,
27+
pub db: String,
28+
#[serde(default = "default_acme")]
29+
pub acme: String,
30+
pub name: String,
31+
}
32+
33+
#[derive(Deserialize, Debug)]
34+
pub struct Config {
35+
pub general: General,
36+
pub api: Api,
37+
#[serde(default)]
38+
pub records: HashMap<String, Vec<Vec<String>>>,
39+
}
40+
41+
const DEFAULT_CONFIG_PATH: &str = "config.toml";
42+
43+
// is not async so we can use it to load settings for tokio runtime
44+
pub fn load_config(config_path: Option<String>) -> Result<Config> {
45+
let config_path = config_path.as_deref().unwrap_or(DEFAULT_CONFIG_PATH);
46+
47+
let span = info_span!("load_config", config_path);
48+
let _enter = span.enter();
49+
50+
let mut file = File::open(config_path)?;
51+
debug!(?file, "Opened file");
52+
53+
let mut bytes = vec![];
54+
file.read_to_end(&mut bytes)?;
55+
debug!(file_length = bytes.len(), "Read file");
56+
57+
let config = toml::de::from_slice::<Config>(&bytes)?;
58+
// redact db information
59+
let config_str = format!("{:?}", config).replace(&config.general.db, "******");
60+
info!(config = %config_str, "Deserialized config");
61+
62+
Ok(config)
63+
}

src/config/records.rs

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
use serde::de::{DeserializeSeed, Error as DeError, MapAccess, SeqAccess, Visitor};
2+
use serde::{Deserialize, Deserializer};
3+
use std::collections::HashMap;
4+
use std::fmt::Formatter;
5+
use std::str::FromStr;
6+
use std::sync::Arc;
7+
use trust_dns_server::proto::rr::{Name, RecordSet, RecordType};
8+
9+
#[derive(Default)]
10+
struct PreconfiguredRecords(HashMap<Name, HashMap<&str, Arc<RecordSet>>>);
11+
12+
impl<'de> Deserialize<'de> for PreconfiguredRecords {
13+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
14+
where
15+
D: Deserializer<'de>,
16+
{
17+
struct PreconfiguredRecordsVisitor;
18+
impl<'de> Visitor<'de> for PreconfiguredRecordsVisitor {
19+
type Value = PreconfiguredRecords;
20+
21+
fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
22+
formatter.write_str("PreconfiguredRecords")
23+
}
24+
25+
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
26+
where
27+
A: MapAccess<'de>,
28+
{
29+
let mut res = HashMap::new();
30+
while let Some(key) = map.next_key::<&str>()? {
31+
let name = match Name::from_str(key) {
32+
Err(e) => return Err(DeError::custom(e)),
33+
Ok(name) => name,
34+
};
35+
36+
let (record_type, record_set) =
37+
map.next_value_seed(RecordDataSeed(name.clone()))?;
38+
39+
res.insert(record_type, record_set);
40+
}
41+
42+
Ok(PreconfiguredRecords(res))
43+
}
44+
}
45+
46+
deserializer.deserialize_map(PreconfiguredRecordsVisitor)
47+
}
48+
}
49+
50+
struct RecordDataSeed(Name);
51+
52+
impl<'de> DeserializeSeed<'de> for RecordDataSeed {
53+
type Value = (RecordType, Arc<RecordSet>);
54+
55+
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
56+
where
57+
D: Deserializer<'de>,
58+
{
59+
struct RecordDataVisitor(Name);
60+
impl<'de> Visitor<'de> for RecordDataVisitor {
61+
type Value = (RecordType, Arc<RecordSet>);
62+
63+
fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
64+
formatter.write_str("RecordData")
65+
}
66+
67+
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
68+
where
69+
A: SeqAccess<'de>,
70+
{
71+
let record_type = match seq.next_element::<&str>()? {
72+
Some(record_type) => record_type,
73+
None => return Err(DeError::custom("Could not find RecordType")),
74+
};
75+
76+
let ttl = match seq.next_element::<u32>()? {
77+
Some(ttl) => ttl,
78+
None => return Err(DeError::custom("Could not find TTL")),
79+
};
80+
81+
let record_set = Arc::new(RecordSet::with_ttl(self.0, record_type, ttl));
82+
83+
Ok((record_type, record_set))
84+
}
85+
}
86+
87+
deserializer.deserialize_seq(RecordDataVisitor(self.0))
88+
}
89+
}

0 commit comments

Comments
 (0)