Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support PIVOT table syntax #836

Merged
merged 1 commit into from
Mar 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
42 changes: 42 additions & 0 deletions src/ast/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,18 @@ pub enum TableFactor {
table_with_joins: Box<TableWithJoins>,
alias: Option<TableAlias>,
},
/// Represents PIVOT operation on a table.
/// For example `FROM monthly_sales PIVOT(sum(amount) FOR MONTH IN ('JAN', 'FEB'))`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a link to the Snowflake docs (https://docs.snowflake.com/en/sql-reference/constructs/pivot) in here?

/// See <https://docs.snowflake.com/en/sql-reference/constructs/pivot>
Pivot {
#[cfg_attr(feature = "visitor", visit(with = "visit_relation"))]
name: ObjectName,
table_alias: Option<TableAlias>,
aggregate_function: Expr, // Function expression
value_column: Vec<Ident>,
pivot_values: Vec<Value>,
pivot_alias: Option<TableAlias>,
},
}

impl fmt::Display for TableFactor {
Expand Down Expand Up @@ -742,6 +754,36 @@ impl fmt::Display for TableFactor {
}
Ok(())
}
TableFactor::Pivot {
name,
table_alias,
aggregate_function,
value_column,
pivot_values,
pivot_alias,
} => {
write!(f, "{}", name)?;
if table_alias.is_some() {
write!(f, " AS {}", table_alias.as_ref().unwrap())?;
}
write!(
f,
" PIVOT({} FOR {} IN (",
aggregate_function,
Expr::CompoundIdentifier(value_column.to_vec())
)?;
for value in pivot_values {
write!(f, "{}", value)?;
if !value.eq(pivot_values.last().unwrap()) {
write!(f, ", ")?;
}
}
write!(f, "))")?;
if pivot_alias.is_some() {
write!(f, " AS {}", pivot_alias.as_ref().unwrap())?;
}
Ok(())
}
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ define_keywords!(
PERCENTILE_DISC,
PERCENT_RANK,
PERIOD,
PIVOT,
PLACING,
PLANS,
PORTION,
Expand Down Expand Up @@ -647,6 +648,7 @@ pub const RESERVED_FOR_TABLE_ALIAS: &[Keyword] = &[
Keyword::SORT,
Keyword::HAVING,
Keyword::ORDER,
Keyword::PIVOT,
Keyword::TOP,
Keyword::LATERAL,
Keyword::VIEW,
Expand Down
40 changes: 40 additions & 0 deletions src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5672,6 +5672,9 @@ impl<'a> Parser<'a> {
| TableFactor::Table { alias, .. }
| TableFactor::UNNEST { alias, .. }
| TableFactor::TableFunction { alias, .. }
| TableFactor::Pivot {
pivot_alias: alias, ..
}
| TableFactor::NestedJoin { alias, .. } => {
// but not `FROM (mytable AS alias1) AS alias2`.
if let Some(inner_alias) = alias {
Expand Down Expand Up @@ -5729,13 +5732,21 @@ impl<'a> Parser<'a> {
})
} else {
let name = self.parse_object_name()?;

// Postgres, MSSQL: table-valued functions:
let args = if self.consume_token(&Token::LParen) {
Some(self.parse_optional_args()?)
} else {
None
};

let alias = self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?;

// Pivot
if self.parse_keyword(Keyword::PIVOT) {
return self.parse_pivot_table_factor(name, alias);
}

// MSSQL-specific table hints:
let mut with_hints = vec![];
if self.parse_keyword(Keyword::WITH) {
Expand Down Expand Up @@ -5773,6 +5784,35 @@ impl<'a> Parser<'a> {
})
}

pub fn parse_pivot_table_factor(
&mut self,
name: ObjectName,
table_alias: Option<TableAlias>,
) -> Result<TableFactor, ParserError> {
self.expect_token(&Token::LParen)?;
let function_name = match self.next_token().token {
Token::Word(w) => Ok(w.value),
_ => self.expected("an aggregate function name", self.peek_token()),
}?;
let function = self.parse_function(ObjectName(vec![Ident::new(function_name)]))?;
self.expect_keyword(Keyword::FOR)?;
let value_column = self.parse_object_name()?.0;
self.expect_keyword(Keyword::IN)?;
self.expect_token(&Token::LParen)?;
let pivot_values = self.parse_comma_separated(Parser::parse_value)?;
self.expect_token(&Token::RParen)?;
self.expect_token(&Token::RParen)?;
let alias = self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?;
Ok(TableFactor::Pivot {
name,
table_alias,
aggregate_function: function,
value_column,
pivot_values,
pivot_alias: alias,
})
}

pub fn parse_join_constraint(&mut self, natural: bool) -> Result<JoinConstraint, ParserError> {
if natural {
Ok(JoinConstraint::Natural)
Expand Down
62 changes: 62 additions & 0 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
use matches::assert_matches;

use sqlparser::ast::SelectItem::UnnamedExpr;
use sqlparser::ast::TableFactor::Pivot;
use sqlparser::ast::*;
use sqlparser::dialect::{
AnsiDialect, BigQueryDialect, ClickHouseDialect, GenericDialect, HiveDialect, MsSqlDialect,
Expand Down Expand Up @@ -6718,6 +6719,67 @@ fn parse_with_recursion_limit() {
assert!(matches!(res, Ok(_)), "{res:?}");
}

#[test]
fn parse_pivot_table() {
let sql = concat!(
"SELECT * FROM monthly_sales AS a ",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • can you add a test case without the AS component?
  • can you add a roundtrip test (so we can verify the Display() implementation works too)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • There is below a variable sql_without_table_alias that tests this.
  • Will add the roundtrip as it is missing.

"PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
"ORDER BY EMPID"
);

assert_eq!(
verified_only_select(sql).from[0].relation,
Pivot {
name: ObjectName(vec![Ident::new("monthly_sales")]),
table_alias: Some(TableAlias {
name: Ident::new("a"),
columns: vec![]
}),
aggregate_function: Expr::Function(Function {
name: ObjectName(vec![Ident::new("SUM")]),
args: (vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(
Expr::CompoundIdentifier(vec![Ident::new("a"), Ident::new("amount"),])
))]),
over: None,
distinct: false,
special: false,
}),
value_column: vec![Ident::new("a"), Ident::new("MONTH")],
pivot_values: vec![
Value::SingleQuotedString("JAN".to_string()),
Value::SingleQuotedString("FEB".to_string()),
Value::SingleQuotedString("MAR".to_string()),
Value::SingleQuotedString("APR".to_string()),
],
pivot_alias: Some(TableAlias {
name: Ident {
value: "p".to_string(),
quote_style: None
},
columns: vec![Ident::new("c"), Ident::new("d")],
}),
}
);
assert_eq!(verified_stmt(sql).to_string(), sql);

let sql_without_table_alias = concat!(
"SELECT * FROM monthly_sales ",
"PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
"ORDER BY EMPID"
);
assert_matches!(
verified_only_select(sql_without_table_alias).from[0].relation,
Pivot {
table_alias: None, // parsing should succeed with empty alias
..
}
);
assert_eq!(
verified_stmt(sql_without_table_alias).to_string(),
sql_without_table_alias
);
}

/// Makes a predicate that looks like ((user_id = $id) OR user_id = $2...)
fn make_where_clause(num: usize) -> String {
use std::fmt::Write;
Expand Down