Skip to content
This repository has been archived by the owner on Mar 23, 2021. It is now read-only.

Commit

Permalink
Clean up (de)serialization of allowed origins for CORS
Browse files Browse the repository at this point in the history
The solution could be simpler if
toml-rs/toml-rs#334 were fixed. It
would be even simpler if `serde` supported flattening enum variants,
but it doesn't: serde-rs/serde#1402.

Also:

- Rename config file field from `allowed_foreign_origins` to
  `allowed_origins` since being under the CORS section already
  indicates that it refers to foreign origins.
  • Loading branch information
luckysori committed Nov 15, 2019
1 parent 5814ff1 commit 5d6023d
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 31 deletions.
48 changes: 31 additions & 17 deletions cnd/src/config/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,14 @@ impl File {
address: IpAddr::V4(Ipv4Addr::UNSPECIFIED),
port: 8000,
},
cors: None,
cors: Some(Cors {
allowed_origins: AllowedOrigins::None(None::None),
}),
},
database: None,
logging: None,
bitcoin: None,
ethereum: None,
database: Option::None,
logging: Option::None,
bitcoin: Option::None,
ethereum: Option::None,
}
}
}
Expand All @@ -71,15 +73,27 @@ pub struct HttpApi {

#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct Cors {
pub allowed_foreign_origins: AllowedForeignOrigins,
pub allowed_origins: AllowedOrigins,
}

#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum AllowedOrigins {
All(All),
None(None),
Some(Vec<String>),
}

#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum AllowedForeignOrigins {
pub enum All {
All,
}

#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum None {
None,
List(Vec<String>),
}

#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
Expand Down Expand Up @@ -161,7 +175,7 @@ impl File {

fn ensure_directory_exists(config_file: &Path) -> Result<(), config_rs::ConfigError> {
match config_file.parent() {
None => Ok(()),
Option::None => Ok(()),
Some(path) => {
if !path.exists() {
println!(
Expand Down Expand Up @@ -228,7 +242,7 @@ mod tests {
assert_that(&config_file).is_ok_containing(LoggingOnlyConfig {
logging: Logging {
level: Some(LevelFilter::Debug),
structured: None,
structured: Option::None,
},
});
}
Expand Down Expand Up @@ -278,25 +292,25 @@ mod tests {
fn cors_deserializes_correctly() {
let file_contents = vec![
r#"
allowed_foreign_origins = "all"
allowed_origins = "all"
"#,
r#"
allowed_foreign_origins = "none"
allowed_origins = "none"
"#,
r#"
allowed_foreign_origins = { list = ["http://localhost:8000", "https://192.168.1.55:3000"] }
allowed_origins = ["http://localhost:8000", "https://192.168.1.55:3000"]
"#,
];

let expected = vec![
Cors {
allowed_foreign_origins: AllowedForeignOrigins::All,
allowed_origins: AllowedOrigins::All(All::All),
},
Cors {
allowed_foreign_origins: AllowedForeignOrigins::None,
allowed_origins: AllowedOrigins::None(None::None),
},
Cors {
allowed_foreign_origins: AllowedForeignOrigins::List(vec![
allowed_origins: AllowedOrigins::Some(vec![
String::from("http://localhost:8000"),
String::from("https://192.168.1.55:3000"),
]),
Expand All @@ -319,7 +333,7 @@ mod tests {
#[test]
fn complete_logging_section_is_optional() {
let config_without_logging_section = File {
logging: None,
logging: Option::None,
..File::default()
};
let temp_file = temp_toml_file();
Expand Down
4 changes: 2 additions & 2 deletions cnd/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ mod serde_duration;
mod settings;

pub use self::{
file::{AllowedForeignOrigins, File},
settings::Settings,
file::File,
settings::{AllowedOrigins, Settings},
};
33 changes: 28 additions & 5 deletions cnd/src/config/settings.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::file::{AllowedForeignOrigins, Cors, Database, File, Network, Socket};
use super::file::{self, Database, File, Network, Socket};
use crate::config::file::{Bitcoin, Ethereum};
use log::LevelFilter;
use reqwest::Url;
Expand All @@ -25,6 +25,18 @@ pub struct HttpApi {
pub cors: Cors,
}

#[derive(Clone, Debug, PartialEq)]
pub struct Cors {
pub allowed_origins: AllowedOrigins,
}

#[derive(Clone, Debug, PartialEq)]
pub enum AllowedOrigins {
All,
None,
Some(Vec<String>),
}

#[derive(Clone, Debug, PartialEq, derivative::Derivative)]
#[derivative(Default)]
pub struct Logging {
Expand All @@ -48,9 +60,20 @@ impl Settings {
network,
http_api: HttpApi {
socket: http_api.socket,
cors: http_api.cors.unwrap_or(Cors {
allowed_foreign_origins: AllowedForeignOrigins::None,
}),
cors: http_api
.cors
.map(|cors| {
let allowed_origins = match cors.allowed_origins {
file::AllowedOrigins::All(_) => AllowedOrigins::All,
file::AllowedOrigins::None(_) => AllowedOrigins::None,
file::AllowedOrigins::Some(origins) => AllowedOrigins::Some(origins),
};

Cors { allowed_origins }
})
.unwrap_or(Cors {
allowed_origins: AllowedOrigins::None,
}),
},
database,
logging: {
Expand Down Expand Up @@ -155,7 +178,7 @@ mod tests {
assert_that(&settings)
.map(|settings| &settings.http_api.cors)
.is_equal_to(Cors {
allowed_foreign_origins: AllowedForeignOrigins::None,
allowed_origins: AllowedOrigins::None,
})
}
}
12 changes: 6 additions & 6 deletions cnd/src/http_api/route_factory.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
config::AllowedForeignOrigins,
config::AllowedOrigins,
db::SaveRfc003Messages,
http_api,
network::{Network, SendRequest},
Expand Down Expand Up @@ -35,7 +35,7 @@ pub fn create<
>(
peer_id: PeerId,
dependencies: D,
allowed_foreign_origins: AllowedForeignOrigins,
allowed_origins: &AllowedOrigins,
) -> BoxedFilter<(impl Reply,)> {
let swaps = warp::path(http_api::PATH);
let rfc003 = swaps.and(warp::path(RFC003));
Expand All @@ -46,10 +46,10 @@ pub fn create<
let cors = warp::cors()
.allow_methods(vec!["GET", "POST"])
.allow_header("content-type");
let cors = match allowed_foreign_origins {
AllowedForeignOrigins::None => cors.allow_origins(Vec::<&str>::new()),
AllowedForeignOrigins::All => cors.allow_any_origin(),
AllowedForeignOrigins::List(hosts) => {
let cors = match allowed_origins {
AllowedOrigins::None => cors.allow_origins(Vec::<&str>::new()),
AllowedOrigins::All => cors.allow_any_origin(),
AllowedOrigins::Some(hosts) => {
cors.allow_origins::<Vec<&str>>(hosts.iter().map(|host| host.as_str()).collect())
}
};
Expand Down
2 changes: 1 addition & 1 deletion cnd/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ fn spawn_warp_instance<
let routes = route_factory::create(
peer_id,
dependencies,
settings.http_api.cors.allowed_foreign_origins.clone(),
&settings.http_api.cors.allowed_origins,
);

let listen_addr = SocketAddr::new(
Expand Down

0 comments on commit 5d6023d

Please sign in to comment.