Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: TraceState cannot insert new key-value pairs. #567

Merged
merged 2 commits into from
Jun 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
39 changes: 25 additions & 14 deletions opentelemetry-aws/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,11 @@
#[cfg(feature = "trace")]
pub mod trace {
use opentelemetry::{
global::{self, Error},
propagation::{text_map_propagator::FieldIter, Extractor, Injector, TextMapPropagator},
trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId, TraceState},
trace::{
SpanContext, SpanId, TraceContextExt, TraceError, TraceFlags, TraceId, TraceState,
},
Context,
};
use std::convert::{TryFrom, TryInto};
Expand Down Expand Up @@ -125,21 +128,29 @@ pub mod trace {
}
}

let trace_state: TraceState = TraceState::from_key_value(kv_vec)?;

if trace_id.to_u128() == 0 {
return Err(());
}
match TraceState::from_key_value(kv_vec) {
Ok(trace_state) => {
if trace_id.to_u128() == 0 {
return Err(());
}

let context: SpanContext = SpanContext::new(
trace_id,
parent_segment_id,
sampling_decision,
true,
trace_state,
);
let context: SpanContext = SpanContext::new(
trace_id,
parent_segment_id,
sampling_decision,
true,
trace_state,
);

Ok(context)
Ok(context)
}
Err(trace_state_err) => {
global::handle_error(Error::Trace(TraceError::Other(Box::new(
trace_state_err,
))));
Err(()) //todo: assign an error type instead of using ()
}
}
}
}

Expand Down
15 changes: 13 additions & 2 deletions opentelemetry-jaeger/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,11 @@ mod propagator {
//!
//! [`Jaeger documentation`]: https://www.jaegertracing.io/docs/1.18/client-libraries/#propagation-format
use opentelemetry::{
global::{self, Error},
propagation::{text_map_propagator::FieldIter, Extractor, Injector, TextMapPropagator},
trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId, TraceState},
trace::{
SpanContext, SpanId, TraceContextExt, TraceError, TraceFlags, TraceId, TraceState,
},
Context,
};
use std::borrow::Cow;
Expand Down Expand Up @@ -324,7 +327,15 @@ mod propagator {
.map(|value| (key.to_string(), value.to_string()))
});

TraceState::from_key_value(uber_context_keys)
match TraceState::from_key_value(uber_context_keys) {
Ok(trace_state) => Ok(trace_state),
Err(trace_state_err) => {
global::handle_error(Error::Trace(TraceError::Other(Box::new(
trace_state_err,
))));
Err(()) //todo: assign an error type instead of using ()
}
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion opentelemetry/src/sdk/trace/tracer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ mod tests {
SamplingResult {
decision: SamplingDecision::RecordAndSample,
attributes: Vec::new(),
trace_state: trace_state.insert("foo".into(), "notbar".into()).unwrap(),
trace_state: trace_state.insert("foo", "notbar").unwrap(),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion opentelemetry/src/trace/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ pub use self::{
noop::{NoopSpan, NoopSpanExporter, NoopTracer, NoopTracerProvider},
provider::TracerProvider,
span::{Span, SpanKind, StatusCode},
span_context::{SpanContext, SpanId, TraceFlags, TraceId, TraceState},
span_context::{SpanContext, SpanId, TraceFlags, TraceId, TraceState, TraceStateError},
tracer::{SpanBuilder, Tracer},
};
use crate::sdk::export::ExportError;
Expand Down
93 changes: 65 additions & 28 deletions opentelemetry/src/trace/span_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::collections::VecDeque;
use std::fmt;
use std::ops::{BitAnd, BitOr, Not};
use std::str::FromStr;
use thiserror::Error;

/// Flags that can be set on a [`SpanContext`].
///
Expand Down Expand Up @@ -238,16 +239,15 @@ impl TraceState {
/// # Examples
///
/// ```
/// use opentelemetry::trace::TraceState;
/// use opentelemetry::trace::{TraceState, TraceStateError};
///
/// let kvs = vec![("foo", "bar"), ("apple", "banana")];
/// let trace_state: Result<TraceState, ()> = TraceState::from_key_value(kvs);
/// let trace_state: Result<TraceState, TraceStateError> = TraceState::from_key_value(kvs);
///
/// assert!(trace_state.is_ok());
/// assert_eq!(trace_state.unwrap().header(), String::from("foo=bar,apple=banana"))
/// ```
#[allow(clippy::all)]
pub fn from_key_value<T, K, V>(trace_state: T) -> Result<Self, ()>
pub fn from_key_value<T, K, V>(trace_state: T) -> Result<Self, TraceStateError>
where
T: IntoIterator<Item = (K, V)>,
K: ToString,
Expand All @@ -257,14 +257,16 @@ impl TraceState {
.into_iter()
.map(|(key, value)| {
let (key, value) = (key.to_string(), value.to_string());
if !TraceState::valid_key(key.as_str()) || !TraceState::valid_value(value.as_str())
{
return Err(());
if !TraceState::valid_key(key.as_str()) {
return Err(TraceStateError::InvalidKey(key));
}
if !TraceState::valid_value(value.as_str()) {
return Err(TraceStateError::InvalidValue(value));
}

Ok((key, value))
})
.collect::<Result<VecDeque<_>, ()>>()?;
.collect::<Result<VecDeque<_>, TraceStateError>>()?;

if ordered_data.is_empty() {
Ok(TraceState(None))
Expand Down Expand Up @@ -292,13 +294,20 @@ impl TraceState {
/// updated key/value is returned.
///
/// ['spec']: https://www.w3.org/TR/trace-context/#list
#[allow(clippy::all)]
pub fn insert(&self, key: String, value: String) -> Result<TraceState, ()> {
if !TraceState::valid_key(key.as_str()) || !TraceState::valid_value(value.as_str()) {
return Err(());
pub fn insert<K, V>(&self, key: K, value: V) -> Result<TraceState, TraceStateError>
where
K: Into<String>,
V: Into<String>,
{
let (key, value) = (key.into(), value.into());
if !TraceState::valid_key(key.as_str()) {
return Err(TraceStateError::InvalidKey(key));
}
if !TraceState::valid_value(value.as_str()) {
return Err(TraceStateError::InvalidValue(value));
}

let mut trace_state = self.delete(key.clone())?;
let mut trace_state = self.delete_from_deque(key.clone());
let kvs = trace_state.0.get_or_insert(VecDeque::with_capacity(1));

kvs.push_front((key, value));
Expand All @@ -307,26 +316,29 @@ impl TraceState {
}

/// Removes the given key-value pair from the `TraceState`. If the key is invalid per the
/// [W3 Spec]['spec'] or the key does not exist an `Err` is returned. Else, a new `TraceState`
/// [W3 Spec]['spec'] an `Err` is returned. Else, a new `TraceState`
/// with the removed entry is returned.
///
/// If the key is not in `TraceState`. The original `TraceState` will be cloned and returned.
/// ['spec']: https://www.w3.org/TR/trace-context/#list
#[allow(clippy::all)]
pub fn delete(&self, key: String) -> Result<TraceState, ()> {
pub fn delete<K: Into<String>>(&self, key: K) -> Result<TraceState, TraceStateError> {
let key = key.into();
if !TraceState::valid_key(key.as_str()) {
return Err(());
return Err(TraceStateError::InvalidKey(key));
}

let mut owned = self.clone();
let kvs = owned.0.as_mut().ok_or(())?;
Ok(self.delete_from_deque(key))
}

if let Some(index) = kvs.iter().position(|x| *x.0 == *key) {
kvs.remove(index);
} else {
return Err(());
/// Delete key from trace state's deque. The key MUST be valid
fn delete_from_deque(&self, key: String) -> TraceState {
let mut owned = self.clone();
if let Some(kvs) = owned.0.as_mut() {
if let Some(index) = kvs.iter().position(|x| *x.0 == *key) {
kvs.remove(index);
}
}

Ok(owned)
owned
}

/// Creates a new `TraceState` header string, delimiting each key and value with a `=` and each
Expand All @@ -350,15 +362,15 @@ impl TraceState {
}

impl FromStr for TraceState {
type Err = ();
type Err = TraceStateError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
let list_members: Vec<&str> = s.split_terminator(',').collect();
let mut key_value_pairs: Vec<(String, String)> = Vec::with_capacity(list_members.len());

for list_member in list_members {
match list_member.find('=') {
None => return Err(()),
None => return Err(TraceStateError::InvalidList(list_member.to_string())),
Some(separator_index) => {
let (key, value) = list_member.split_at(separator_index);
key_value_pairs
Expand All @@ -371,6 +383,23 @@ impl FromStr for TraceState {
}
}

/// Error returned by `TraceState` operations.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum TraceStateError {
/// The key is invalid. See https://www.w3.org/TR/trace-context/#key for requirement for keys.
#[error("{0} is not a valid key in TraceState, see https://www.w3.org/TR/trace-context/#key for more details")]
InvalidKey(String),

/// The value is invalid. See https://www.w3.org/TR/trace-context/#value for requirement for values.
#[error("{0} is not a valid value in TraceState, see https://www.w3.org/TR/trace-context/#value for more details")]
InvalidValue(String),

/// The value is invalid. See https://www.w3.org/TR/trace-context/#list for requirement for list members.
#[error("{0} is not a valid list member in TraceState, see https://www.w3.org/TR/trace-context/#list for more details")]
InvalidList(String),
}

/// Immutable portion of a `Span` which can be serialized and propagated.
///
/// Spans that do not have the `sampled` flag set in their [`TraceFlags`] will
Expand Down Expand Up @@ -514,7 +543,7 @@ mod tests {

let new_key = format!("{}-{}", test_case.0.get(test_case.2).unwrap(), "test");

let updated_trace_state = test_case.0.insert(test_case.2.into(), new_key.clone());
let updated_trace_state = test_case.0.insert(test_case.2, new_key.clone());
assert!(updated_trace_state.is_ok());
let updated_trace_state = updated_trace_state.unwrap();

Expand All @@ -533,4 +562,12 @@ mod tests {
assert!(deleted_trace_state.get(test_case.2).is_none());
}
}

#[test]
fn test_trace_state_insert() {
let trace_state = TraceState::from_key_value(vec![("foo", "bar")]).unwrap();
let inserted_trace_state = trace_state.insert("testkey", "testvalue").unwrap();
assert!(trace_state.get("testkey").is_none()); // The original state doesn't change
assert_eq!(inserted_trace_state.get("testkey").unwrap(), "testvalue"); //
}
}