Skip to content

Commit

Permalink
extend registry capability & integrate into conformance server
Browse files Browse the repository at this point in the history
Summary:
- extend the any registry with an index on hash prefix
- wire up registry initialization in the conformance server

Reviewed By: dtolnay

Differential Revision: D57496851

fbshipit-source-id: 2c7d9779ef608cd8617b1984fa5a309b14b33d0c
  • Loading branch information
Shayne Fletcher authored and facebook-github-bot committed May 18, 2024
1 parent ea08787 commit 9d21e77
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,6 @@ extern crate serde;

{{/program:serde?}}{{!
}}{{#program:any?}}
use fbthrift_conformance as _; // Work in progress. Not used yet but here to remind us of the link requirements when doing so.

{{/program:any?}}{{!
}}{{#program:types_include_srcs}}
include!("{{.}}");
{{/program:types_include_srcs}}{{!
Expand Down Expand Up @@ -102,6 +97,14 @@ mod dot_dot {{>lib/block}}{{!
}{{!
}}{{/program:nonexhaustiveStructs?}}

{{#program:any?}}
pub fn init_registry(registry: &mut fbthrift_conformance::AnyRegistry) -> anyhow::Result<bool> {
Ok(true)
}

{{/program:any?}}{{!
}}
pub(crate) mod r#impl {{>lib/block}}{{>lib/mod.impl}}}
{{!
}}{{#program:has_default_tests?}}{{>lib/adapter/default_test}}{{/program:has_default_tests?}}
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@ use clap::Parser;
use conformance::services::conformance_service::PatchExn;
use conformance::services::conformance_service::RoundTripExn;
use conformance_services::ConformanceService;
use enum_ as enum_types; // fbcode//thrift/test/testet:enum
use futures::StreamExt;
use patch_data::PatchOpRequest;
use patch_data::PatchOpResponse;
use serialization::RoundTripRequest;
use serialization::RoundTripResponse;
use testset as test_types; // fbcode//thrift/test/testet:testset
use tracing::info;
use tracing_subscriber::layer::SubscriberExt;

#[derive(Debug, Parser)]
Expand All @@ -45,9 +48,23 @@ fn main(fb: fbinit::FacebookInit) -> Result<()> {

init_logging(args.log);

let any_registry = Box::leak(Box::new(fbthrift_conformance::AnyRegistry::new()));
test_types::init_registry(any_registry)?;
enum_types::init_registry(any_registry)?;
info!(
"\"Any registry\" initialized, {} types registered",
any_registry.num_registered_types()
);

let runtime = tokio::runtime::Runtime::new()?;
let service = move |proto| {
conformance_services::make_ConformanceService_server(proto, ConformanceServiceImpl { fb })
let service = {
let any_registry = any_registry as &'static _;
move |proto| {
conformance_services::make_ConformanceService_server(
proto,
ConformanceServiceImpl { fb, any_registry },
)
}
};
let thrift_server = srserver::ThriftServerBuilder::new(fb)
.with_port(args.port)
Expand Down Expand Up @@ -102,6 +119,7 @@ fn init_logging(directives: Vec<tracing_subscriber::filter::Directive>) {
#[derive(Clone)]
pub struct ConformanceServiceImpl {
pub fb: fbinit::FacebookInit,
pub any_registry: &'static fbthrift_conformance::AnyRegistry,
}

#[async_trait]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ pub struct AnySerDeser {
pub struct AnyRegistry {
uri_to_typeid: HashMap<&'static str, TypeId>,
typeid_to_uri: HashMap<TypeId, &'static str>,
hash_prefix_to_typeid: HashMap<Vec<u8>, TypeId>,
alg_to_hashes: HashMap<UniversalHashAlgorithm, HashSet<Vec<u8>>>,
typeid_to_serializers: HashMap<TypeId, AnySerDeser>,
}
Expand All @@ -63,6 +64,7 @@ impl AnyRegistry {
Self {
uri_to_typeid: HashMap::new(),
typeid_to_uri: HashMap::new(),
hash_prefix_to_typeid: HashMap::new(),
typeid_to_serializers: HashMap::new(),
alg_to_hashes,
}
Expand All @@ -72,12 +74,19 @@ impl AnyRegistry {
&mut self,
) -> Result<bool> {
let uri = T::uri();
let hash_prefix =
get_universal_hash_prefix_sha_256(uri, UNIVERSAL_HASH_PREFIX_SHA_256_LEN)?;
let type_id = TypeId::of::<T>();
if self.uri_to_typeid.contains_key(uri) || self.typeid_to_uri.contains_key(&type_id) {

if self.uri_to_typeid.contains_key(uri)
|| self.hash_prefix_to_typeid.contains_key(&hash_prefix)
{
return Ok(false);
}
self.uri_to_typeid.insert(uri, type_id);
self.typeid_to_uri.insert(type_id, uri);
self.hash_prefix_to_typeid.insert(hash_prefix, type_id);

for (alg, hashes) in self.alg_to_hashes.iter_mut() {
let hash = get_universal_hash(*alg, uri)?;
hashes.insert(hash);
Expand All @@ -99,6 +108,10 @@ impl AnyRegistry {
Ok(true)
}

pub fn num_registered_types(&self) -> usize {
self.uri_to_typeid.len()
}

pub fn has_type<T: 'static + GetUri>(&self, obj: &any::Any) -> Result<bool> {
let type_uri = T::uri();
let type_hash_prefix_sha2_256 =
Expand Down Expand Up @@ -152,7 +165,7 @@ impl AnyRegistry {
deserialize(&obj.data, obj.protocol.unwrap_or(StandardProtocol::Compact))
}

pub fn serializers(&self, uri: &str) -> Result<&AnySerDeser> {
pub fn serializers_given_uri(&self, uri: &str) -> Result<&AnySerDeser> {
self.typeid_to_serializers
.get(
self.uri_to_typeid
Expand All @@ -161,6 +174,16 @@ impl AnyRegistry {
)
.context("serializers lookup failure")
}

pub fn serializers_given_hash_prefix(&self, hash_prefix: &Vec<u8>) -> Result<&AnySerDeser> {
self.typeid_to_serializers
.get(
self.hash_prefix_to_typeid
.get(hash_prefix)
.context("typeid lookup failure")?,
)
.context("serializers lookup failure")
}
}

fn serialize<T: SerializeRef>(obj: &T, protocol: StandardProtocol) -> Result<Vec<u8>> {
Expand Down Expand Up @@ -259,14 +282,43 @@ mod tests {
}

#[test]
fn test_round_trip_through_any() -> Result<()> {
fn test_round_trip_through_any_via_uri() -> Result<()> {
let mut any_registry = AnyRegistry::new();
any_registry.register_type::<struct_map_string_i32>()?;

let uri = struct_map_string_i32::uri();
let AnySerDeser {
serialize,
deserialize,
} = any_registry.serializers_given_uri(uri)?;

let obj = get_test_object();
for protocol in get_test_protocols() {
let any: Box<dyn std::any::Any> = Box::new(obj.clone());
let bytes = serialize(any, protocol)?;
assert!(!bytes.is_empty());
let val = *(deserialize(&bytes, protocol)?
.downcast::<struct_map_string_i32>()
.map_err(|_| anyhow::Error::msg("bad any cast")))?;
assert_eq!(val, obj);
}

Ok(())
}

#[test]
fn test_round_trip_through_any_via_hash_prefix() -> Result<()> {
let mut any_registry = AnyRegistry::new();
any_registry.register_type::<struct_map_string_i32>()?;

let hash_prefix = get_universal_hash_prefix_sha_256(
struct_map_string_i32::uri(),
UNIVERSAL_HASH_PREFIX_SHA_256_LEN,
)?;
let AnySerDeser {
serialize,
deserialize,
} = any_registry.serializers(struct_map_string_i32::uri())?;
} = any_registry.serializers_given_hash_prefix(&hash_prefix)?;

let obj = get_test_object();
for protocol in get_test_protocols() {
Expand All @@ -275,7 +327,7 @@ mod tests {
assert!(!bytes.is_empty());
let val = *(deserialize(&bytes, protocol)?
.downcast::<struct_map_string_i32>()
.map_err(|_| anyhow::Error::msg("cast failure")))?;
.map_err(|_| anyhow::Error::msg("bad any cast")))?;
assert_eq!(val, obj);
}

Expand Down

0 comments on commit 9d21e77

Please sign in to comment.