Skip to content

Commit

Permalink
support PIVOT table syntax
Browse files Browse the repository at this point in the history
Signed-off-by: Pawel Leszczynski <leszczynski.pawel@gmail.com>
  • Loading branch information
pawel-big-lebowski committed Mar 14, 2023
1 parent 4ff3aeb commit 0fcafd9
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 0 deletions.
41 changes: 41 additions & 0 deletions src/ast/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,17 @@ pub enum TableFactor {
table_with_joins: Box<TableWithJoins>,
alias: Option<TableAlias>,
},
/// Represents PIVOT operation on a table.
/// For example `FROM monthly_sales PIVOT(sum(amount) FOR MONTH IN ('JAN', 'FEB'))`
Pivot {
#[cfg_attr(feature = "visitor", visit(with = "visit_relation"))]
name: ObjectName,
table_alias: Option<TableAlias>,
aggregate_function: Expr, // Function expression
value_column: Vec<Ident>,
pivot_values: Vec<Value>,
pivot_alias: Option<TableAlias>,
},
}

impl fmt::Display for TableFactor {
Expand Down Expand Up @@ -742,6 +753,36 @@ impl fmt::Display for TableFactor {
}
Ok(())
}
TableFactor::Pivot {
name,
table_alias,
aggregate_function,
value_column,
pivot_values,
pivot_alias,
} => {
write!(f, "{}", name)?;
if table_alias.is_some() {
write!(f, " AS {}", table_alias.as_ref().unwrap())?;
}
write!(
f,
" PIVOT({} FOR {} IN (",
aggregate_function,
Expr::CompoundIdentifier(value_column.to_vec())
)?;
for value in pivot_values {
write!(f, "{}", value)?;
if !value.eq(pivot_values.last().unwrap()) {
write!(f, ", ")?;
}
}
write!(f, "))")?;
if pivot_alias.is_some() {
write!(f, " AS {}", pivot_alias.as_ref().unwrap())?;
}
Ok(())
}
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ define_keywords!(
PERCENTILE_DISC,
PERCENT_RANK,
PERIOD,
PIVOT,
PLACING,
PLANS,
PORTION,
Expand Down Expand Up @@ -647,6 +648,7 @@ pub const RESERVED_FOR_TABLE_ALIAS: &[Keyword] = &[
Keyword::SORT,
Keyword::HAVING,
Keyword::ORDER,
Keyword::PIVOT,
Keyword::TOP,
Keyword::LATERAL,
Keyword::VIEW,
Expand Down
46 changes: 46 additions & 0 deletions src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5672,6 +5672,9 @@ impl<'a> Parser<'a> {
| TableFactor::Table { alias, .. }
| TableFactor::UNNEST { alias, .. }
| TableFactor::TableFunction { alias, .. }
| TableFactor::Pivot {
pivot_alias: alias, ..
}
| TableFactor::NestedJoin { alias, .. } => {
// but not `FROM (mytable AS alias1) AS alias2`.
if let Some(inner_alias) = alias {
Expand Down Expand Up @@ -5729,13 +5732,21 @@ impl<'a> Parser<'a> {
})
} else {
let name = self.parse_object_name()?;

// Postgres, MSSQL: table-valued functions:
let args = if self.consume_token(&Token::LParen) {
Some(self.parse_optional_args()?)
} else {
None
};

let alias = self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?;

// Pivot
if self.parse_keyword(Keyword::PIVOT) {
return self.parse_pivot_table_factor(name, alias);
}

// MSSQL-specific table hints:
let mut with_hints = vec![];
if self.parse_keyword(Keyword::WITH) {
Expand Down Expand Up @@ -5773,6 +5784,41 @@ impl<'a> Parser<'a> {
})
}

pub fn parse_pivot_table_factor(
&mut self,
name: ObjectName,
table_alias: Option<TableAlias>,
) -> Result<TableFactor, ParserError> {
self.expect_token(&Token::LParen)?;
let function_name = match self.next_token().token {
Token::Word(w) => Ok(w.value),
_ => self.expected("a function name", self.peek_token()),
}?;
let function = self
.parse_function(ObjectName(vec![Ident::new(function_name)]))
.unwrap();
self.expect_keyword(Keyword::FOR)?;
let mut value_column = vec![self.parse_identifier()?];
while self.next_token().token.eq(&Token::Period) {
value_column.push(self.parse_identifier()?);
}
self.prev_token(); // not a period
self.expect_keyword(Keyword::IN)?;
self.expect_token(&Token::LParen)?;
let pivot_values = self.parse_comma_separated(Parser::parse_value)?;
self.expect_token(&Token::RParen)?;
self.expect_token(&Token::RParen)?;
let alias = self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?;
Ok(TableFactor::Pivot {
name,
table_alias,
aggregate_function: function,
value_column,
pivot_values,
pivot_alias: alias,
})
}

pub fn parse_join_constraint(&mut self, natural: bool) -> Result<JoinConstraint, ParserError> {
if natural {
Ok(JoinConstraint::Natural)
Expand Down
57 changes: 57 additions & 0 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
use matches::assert_matches;

use sqlparser::ast::SelectItem::UnnamedExpr;
use sqlparser::ast::TableFactor::Pivot;
use sqlparser::ast::*;
use sqlparser::dialect::{
AnsiDialect, BigQueryDialect, ClickHouseDialect, GenericDialect, HiveDialect, MsSqlDialect,
Expand Down Expand Up @@ -6718,6 +6719,62 @@ fn parse_with_recursion_limit() {
assert!(matches!(res, Ok(_)), "{res:?}");
}

#[test]
fn parse_pivot_table() {
let sql = concat!(
"SELECT * FROM monthly_sales AS a ",
"PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
"ORDER BY EMPID"
);

assert_eq!(
verified_only_select(sql).from[0].relation,
Pivot {
name: ObjectName(vec![Ident::new("monthly_sales")]),
table_alias: Some(TableAlias {
name: Ident::new("a"),
columns: vec![]
}),
aggregate_function: Expr::Function(Function {
name: ObjectName(vec![Ident::new("SUM")]),
args: (vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(
Expr::CompoundIdentifier(vec![Ident::new("a"), Ident::new("amount"),])
))]),
over: None,
distinct: false,
special: false,
}),
value_column: vec![Ident::new("a"), Ident::new("MONTH")],
pivot_values: vec![
Value::SingleQuotedString("JAN".to_string()),
Value::SingleQuotedString("FEB".to_string()),
Value::SingleQuotedString("MAR".to_string()),
Value::SingleQuotedString("APR".to_string()),
],
pivot_alias: Some(TableAlias {
name: Ident {
value: "p".to_string(),
quote_style: None
},
columns: vec![Ident::new("c"), Ident::new("d")],
}),
}
);

let sql_without_table_alias = concat!(
"SELECT * FROM monthly_sales ",
"PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
"ORDER BY EMPID"
);
assert_matches!(
verified_only_select(sql_without_table_alias).from[0].relation,
Pivot {
table_alias: None, // parsing should succeed with empty alias
..
}
);
}

/// Makes a predicate that looks like ((user_id = $id) OR user_id = $2...)
fn make_where_clause(num: usize) -> String {
use std::fmt::Write;
Expand Down

0 comments on commit 0fcafd9

Please sign in to comment.