Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: RDS Proxy #7329

Merged
merged 5 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
32 changes: 32 additions & 0 deletions moto/rds/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,35 @@ def __init__(self, instance_engine: str, cluster_engine: str) -> None:
f"The engine name requested for your DB instance ({instance_engine}) doesn't match "
f"the engine name of your DB cluster ({cluster_engine})."
)


class InvalidSubnet(RDSClientError):
def __init__(self, subnet_identifier: str):
super().__init__(
"InvalidSubnet",
f"The requested subnet {subnet_identifier} is invalid, or multiple subnets were requested that are not all in a common VPC.",
)


class DBProxyAlreadyExistsFault(RDSClientError):
def __init__(self, db_proxy_identifier: str):
super().__init__(
"DBProxyAlreadyExistsFault",
f"Cannot create the DBProxy because a DBProxy with the identifier {db_proxy_identifier} already exists.",
)


class DBProxyQuotaExceededFault(RDSClientError):
def __init__(self) -> None:
super().__init__(
"DBProxyQuotaExceeded",
"The request cannot be processed because it would exceed the maximum number of DBProxies.",
)


class DBProxyNotFoundFault(RDSClientError):
def __init__(self, db_proxy_identifier: str):
super().__init__(
"DBProxyNotFoundFault",
f"The specified proxy name {db_proxy_identifier} doesn't correspond to a proxy owned by your Amazon Web Services account in the specified Amazon Web Services Region.",
)
188 changes: 187 additions & 1 deletion moto/rds/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
DBClusterToBeDeletedHasActiveMembers,
DBInstanceNotFoundError,
DBParameterGroupNotFoundError,
DBProxyAlreadyExistsFault,
DBProxyNotFoundFault,
DBProxyQuotaExceededFault,
DBSecurityGroupNotFoundError,
DBSnapshotAlreadyExistsError,
DBSnapshotNotFoundError,
Expand All @@ -39,6 +42,7 @@
InvalidGlobalClusterStateFault,
InvalidParameterCombination,
InvalidParameterValue,
InvalidSubnet,
OptionGroupNotFoundFaultError,
RDSClientError,
SnapshotQuotaExceededError,
Expand Down Expand Up @@ -1571,11 +1575,131 @@ def delete(self, account_id: str, region_name: str) -> None:
backend.delete_subnet_group(self.subnet_name)


class DBProxy(BaseModel):
def __init__(
self,
DBProxyName: str,
JoshLevyMN marked this conversation as resolved.
Show resolved Hide resolved
EngineFamily: str,
Auth: List[Dict[str, str]],
RoleArn: str,
VpcSubnetIds: List[str],
region_name: str,
account_id: str,
VpcSecurityGroupIds: Optional[List[str]],
RequireTLS: Optional[bool] = False,
IdleClientTimeout: Optional[int] = 1800,
DebugLogging: Optional[bool] = False,
tags: Optional[List[Dict[str, str]]] = None,
):
self.DBProxyName = DBProxyName
self.EngineFamily = EngineFamily
if self.EngineFamily not in ["MYSQL", "POSTGRESQ", "SQLSERVER"]:
raise InvalidParameterValue("Provided EngineFamily is not valid.")
self.Auth = Auth
self.RoleArn = RoleArn
self.VpcSubnetIds = VpcSubnetIds
self.VpcSecurityGroupIds = VpcSecurityGroupIds
self.RequireTLS = RequireTLS
if IdleClientTimeout is None:
self.IdleClientTimeout = 1800
else:
if int(IdleClientTimeout) < 1:
self.IdleClientTimeout = 1
elif int(IdleClientTimeout) > 28800:
self.IdleClientTimeout = 28800
else:
self.IdleClientTimeout = IdleClientTimeout
self.DebugLogging = DebugLogging
self.CreatedDate = iso_8601_datetime_with_milliseconds()
self.UpdatedDate = iso_8601_datetime_with_milliseconds()
if tags is None:
self.tags = []
else:
self.tags = tags
self.region_name = region_name
self.account_id = account_id
self.DBProxyARN = f"arn:aws:rds:{self.region_name}:{self.account_id}:db-proxy:{self.DBProxyName}"
self.arn = self.DBProxyARN
ec2_backend = ec2_backends[self.account_id][self.region_name]
try:
subnets = ec2_backend.describe_subnets(subnet_ids=self.VpcSubnetIds)
except Exception as e:
raise e
JoshLevyMN marked this conversation as resolved.
Show resolved Hide resolved
vpcs = []
for subnet in subnets:
vpcs.append(subnet.vpc_id)
if subnet.vpc_id != vpcs[0]:
raise InvalidSubnet(subnet_identifier=subnet.id)

self.VpcId = ec2_backend.describe_subnets(subnet_ids=[self.VpcSubnetIds[0]])[
0
].vpc_id
self.Status = "availible"
self.url_identifier = "".join(
random.choice(string.ascii_lowercase + string.digits) for _ in range(12)
)
self.Endpoint = f"{self.DBProxyName}.db-proxy-{self.url_identifier}.{self.region_name}.rds.amazonaws.com"

def get_tags(self) -> List[Dict[str, str]]:
return self.tags

def add_tags(self, tags: List[Dict[str, str]]) -> List[Dict[str, str]]:
new_keys = [tag_set["Key"] for tag_set in tags]
self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys]
self.tags.extend(tags)
return self.tags

def remove_tags(self, tag_keys: List[str]) -> None:
self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys]

def to_xml(self) -> str:
template = Template(
"""
<RequireTLS>{{ dbproxy.RequireTLS }}</RequireTLS>
<VpcSecurityGroupIds>
{% if dbproxy.VpcSecurityGroupIds %}
{% for vpcsecuritygroupid in dbproxy.VpcSecurityGroupIds %}
<member>{{ vpcsecuritygroupid }}</member>
{% endfor %}
{% endif %}
</VpcSecurityGroupIds>
<Auth>
{% for auth in dbproxy.Auth %}
<member>
<UserName>{{ auth["UserName"] }}</UserName>
<AuthScheme>{{ auth["AuthScheme"] }}</AuthScheme>
<SecretArn>{{ auth["SecretArn"] }}</SecretArn>
<IAMAuth>{{ auth["IAMAuth"] }}</IAMAuth>
<ClientPasswordAuthType>{{ auth["ClientPasswordAuthType"] }}</ClientPasswordAuthType>
</member>
{% endfor %}
</Auth>
<EngineFamily>{{ dbproxy.EngineFamily }}</EngineFamily>
<UpdatedDate>{{ dbproxy.UpdatedDate }}</UpdatedDate>
<DBProxyName>{{ dbproxy.DBProxyName }}</DBProxyName>
<IdleClientTimeout>{{ dbproxy.IdleClientTimeout }}</IdleClientTimeout>
<Endpoint>{{ dbproxy.Endpoint }}</Endpoint>
<CreatedDate>{{ dbproxy.CreatedDate }}</CreatedDate>
<RoleArn>{{ dbproxy.RoleArn }}</RoleArn>
<DebugLogging>{{ dbproxy.DebugLogging }}</DebugLogging>
<VpcId>{{ dbproxy.VpcId }}</VpcId>
<DBProxyArn>{{ dbproxy.DBProxyARN }}</DBProxyArn>
<VpcSubnetIds>
{% for vpcsubnetid in dbproxy.VpcSubnetIds %}
<member>{{ vpcsubnetid }}</member>
{% endfor %}
</VpcSubnetIds>
<Status>{{ dbproxy.Status }}</Status>
"""
)
return template.render(dbproxy=self)


class RDSBackend(BaseBackend):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.arn_regex = re_compile(
r"^arn:aws:rds:.*:[0-9]*:(db|cluster|es|og|pg|ri|secgrp|snapshot|cluster-snapshot|subgrp):.*$"
r"^arn:aws:rds:.*:[0-9]*:(db|cluster|es|og|pg|ri|secgrp|snapshot|cluster-snapshot|subgrp|db-proxy):.*$"
)
self.clusters: Dict[str, Cluster] = OrderedDict()
self.global_clusters: Dict[str, GlobalCluster] = OrderedDict()
Expand All @@ -1590,6 +1714,7 @@ def __init__(self, region_name: str, account_id: str):
self.security_groups: Dict[str, SecurityGroup] = {}
self.subnet_groups: Dict[str, SubnetGroup] = {}
self._db_cluster_options: Optional[List[Dict[str, Any]]] = None
self.db_proxies: Dict[str, DBProxy] = OrderedDict()

def reset(self) -> None:
self.neptune.reset()
Expand Down Expand Up @@ -2584,6 +2709,9 @@ def list_tags_for_resource(self, arn: str) -> List[Dict[str, str]]:
elif resource_type == "subgrp": # DB subnet group
if resource_name in self.subnet_groups:
return self.subnet_groups[resource_name].get_tags()
elif resource_type == "db-proxy": # DB Proxy
if resource_name in self.db_proxies:
return self.db_proxies[resource_name].get_tags()
else:
raise RDSClientError(
"InvalidParameterValue", f"Invalid resource name: {arn}"
Expand Down Expand Up @@ -2626,6 +2754,9 @@ def remove_tags_from_resource(self, arn: str, tag_keys: List[str]) -> None:
elif resource_type == "subgrp": # DB subnet group
if resource_name in self.subnet_groups:
self.subnet_groups[resource_name].remove_tags(tag_keys)
elif resource_type == "db-proxy": # DB Proxy
if resource_name in self.db_proxies:
self.db_proxies[resource_name].remove_tags(tag_keys)
else:
raise RDSClientError(
"InvalidParameterValue", f"Invalid resource name: {arn}"
Expand Down Expand Up @@ -2667,6 +2798,9 @@ def add_tags_to_resource(self, arn: str, tags: List[Dict[str, str]]) -> List[Dic
elif resource_type == "subgrp": # DB subnet group
if resource_name in self.subnet_groups:
return self.subnet_groups[resource_name].add_tags(tags)
elif resource_type == "db-proxy": # DB Proxy
if resource_name in self.db_proxies:
return self.db_proxies[resource_name].add_tags(tags)
else:
raise RDSClientError(
"InvalidParameterValue", f"Invalid resource name: {arn}"
Expand Down Expand Up @@ -2908,6 +3042,58 @@ def modify_db_cluster_snapshot_attribute(
)
return snapshot.attributes

def create_db_proxy(
self,
db_proxy_name: str,
engine_family: str,
auth: List[Dict[str, str]],
role_arn: str,
vpc_subnet_ids: List[str],
vpc_security_group_ids: Optional[List[str]],
require_tls: Optional[bool],
idle_client_timeout: Optional[int],
debug_logging: Optional[bool],
tags: Optional[List[Dict[str, str]]],
) -> DBProxy:
self._validate_db_identifier(db_proxy_name)
if db_proxy_name in self.db_proxies:
raise DBProxyAlreadyExistsFault(db_proxy_name)
if len(self.db_proxies) >= int(os.environ.get("MOTO_RDS_PROXY_LIMIT", "100")):
raise DBProxyQuotaExceededFault()
db_proxy = DBProxy(
db_proxy_name,
engine_family,
auth,
role_arn,
vpc_subnet_ids,
self.region_name,
self.account_id,
vpc_security_group_ids,
require_tls,
idle_client_timeout,
debug_logging,
tags,
)
self.db_proxies[db_proxy_name] = db_proxy
return db_proxy

def describe_db_proxies(
self,
db_proxy_name: Optional[str],
filters: Optional[
List[Dict[str, Any]]
] = None, # This parameter is not currently supported. https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds/client/describe_db_proxies.html
) -> List[DBProxy]:
# Filters: This parameter is not currently supported, so it is ignored
JoshLevyMN marked this conversation as resolved.
Show resolved Hide resolved
db_proxies = list(self.db_proxies.values())
print("type(): ", type(db_proxies))
JoshLevyMN marked this conversation as resolved.
Show resolved Hide resolved
if db_proxy_name and db_proxy_name in self.db_proxies.keys():
db_proxies = [self.db_proxies[db_proxy_name]]
print("type2(): ", type(db_proxies))
JoshLevyMN marked this conversation as resolved.
Show resolved Hide resolved
if db_proxy_name and db_proxy_name not in self.db_proxies.keys():
raise DBProxyNotFoundFault(db_proxy_name)
return db_proxies


class OptionGroup:
def __init__(
Expand Down
67 changes: 67 additions & 0 deletions moto/rds/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,46 @@ def modify_db_cluster_snapshot_attribute(self) -> str:
db_cluster_snapshot_identifier=db_cluster_snapshot_identifier,
)

def describe_db_proxies(self) -> str:
params = self._get_params()
db_proxy_name = params.get("DBProxyName")
filters = params.get("Filters")
JoshLevyMN marked this conversation as resolved.
Show resolved Hide resolved
marker = params.get("Marker")
db_proxies, marker = self.backend.describe_db_proxies(
db_proxy_name=db_proxy_name,
filters=filters,
)
template = self.response_template(DESCRIBE_DB_PROXIES_TEMPLATE)
rendered = template.render(dbproxies=db_proxies, marker=marker)
return rendered

def create_db_proxy(self) -> str:
params = self._get_params()
db_proxy_name = params["DBProxyName"]
engine_family = params["EngineFamily"]
auth = params["Auth"]
role_arn = params["RoleArn"]
vpc_subnet_ids = params["VpcSubnetIds"]
vpc_security_group_ids = params.get("VpcSecurityGroupIds")
require_tls = params.get("RequireTLS")
idle_client_timeout = params.get("IdleClientTimeout")
debug_logging = params.get("DebugLogging")
tags = self.unpack_list_params("Tags", "Tag")
db_proxy = self.backend.create_db_proxy(
db_proxy_name=db_proxy_name,
engine_family=engine_family,
auth=auth,
role_arn=role_arn,
vpc_subnet_ids=vpc_subnet_ids,
vpc_security_group_ids=vpc_security_group_ids,
require_tls=require_tls,
idle_client_timeout=idle_client_timeout,
debug_logging=debug_logging,
tags=tags,
)
template = self.response_template(CREATE_DB_PROXY_TEMPLATE)
return template.render(dbproxy=db_proxy)


CREATE_DATABASE_TEMPLATE = """<CreateDBInstanceResponse xmlns="http://rds.amazonaws.com/doc/2014-09-01/">
<CreateDBInstanceResult>
Expand Down Expand Up @@ -1630,3 +1670,30 @@ def modify_db_cluster_snapshot_attribute(self) -> str:
<RequestId>1549581b-12b7-11e3-895e-1334a</RequestId>
</ResponseMetadata>
</DescribeDBClusterSnapshotAttributesResponse>"""

CREATE_DB_PROXY_TEMPLATE = """<CreateDBProxyResponse xmlns="http://rds.amazonaws.com/doc/2014-10-31/">
<CreateDBProxyResult>
<DBProxy>
{{ dbproxy.to_xml() }}
</DBProxy>
</CreateDBProxyResult>
<ResponseMetadata>
<RequestId>1549581b-12b7-11e3-895e-1334aEXAMPLE</RequestId>
</ResponseMetadata>
</CreateDBProxyResponse>"""

DESCRIBE_DB_PROXIES_TEMPLATE = """<DescribeDBProxiesResponse xmlns="http://rds.amazonaws.com/doc/2014-10-31/">
<DescribeDBProxiesResult>
<DBProxies>
{% for dbproxy in dbproxies %}
<member>
{{ dbproxy.to_xml() }}
</member>
{% endfor %}
</DBProxies>
</DescribeDBProxiesResult>
<ResponseMetadata>
<RequestId>1549581b-12b7-11e3-895e-1334a</RequestId>
</ResponseMetadata>
</DescribeDBProxiesResponse>
"""
1 change: 1 addition & 0 deletions moto/resourcegroupstaggingapi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@ def format_tag_keys(
"rds:db": self.rds_backend.databases,
"rds:snapshot": self.rds_backend.database_snapshots,
"rds:cluster-snapshot": self.rds_backend.cluster_snapshots,
"rds:db-proxy": self.rds_backend.db_proxies,
}
for resource_type, resource_source in resource_map.items():
if (
Expand Down