Skip to content

Commit

Permalink
Feature: RDS Proxy (#7329)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshLevyMN committed Feb 16, 2024
1 parent aa043a0 commit 59248f3
Show file tree
Hide file tree
Showing 5 changed files with 548 additions and 1 deletion.
32 changes: 32 additions & 0 deletions moto/rds/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,35 @@ def __init__(self, instance_engine: str, cluster_engine: str) -> None:
f"The engine name requested for your DB instance ({instance_engine}) doesn't match "
f"the engine name of your DB cluster ({cluster_engine})."
)


class InvalidSubnet(RDSClientError):
def __init__(self, subnet_identifier: str):
super().__init__(
"InvalidSubnet",
f"The requested subnet {subnet_identifier} is invalid, or multiple subnets were requested that are not all in a common VPC.",
)


class DBProxyAlreadyExistsFault(RDSClientError):
def __init__(self, db_proxy_identifier: str):
super().__init__(
"DBProxyAlreadyExistsFault",
f"Cannot create the DBProxy because a DBProxy with the identifier {db_proxy_identifier} already exists.",
)


class DBProxyQuotaExceededFault(RDSClientError):
def __init__(self) -> None:
super().__init__(
"DBProxyQuotaExceeded",
"The request cannot be processed because it would exceed the maximum number of DBProxies.",
)


class DBProxyNotFoundFault(RDSClientError):
def __init__(self, db_proxy_identifier: str):
super().__init__(
"DBProxyNotFoundFault",
f"The specified proxy name {db_proxy_identifier} doesn't correspond to a proxy owned by your Amazon Web Services account in the specified Amazon Web Services Region.",
)
183 changes: 182 additions & 1 deletion moto/rds/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
DBClusterToBeDeletedHasActiveMembers,
DBInstanceNotFoundError,
DBParameterGroupNotFoundError,
DBProxyAlreadyExistsFault,
DBProxyNotFoundFault,
DBProxyQuotaExceededFault,
DBSecurityGroupNotFoundError,
DBSnapshotAlreadyExistsError,
DBSnapshotNotFoundError,
Expand All @@ -39,6 +42,7 @@
InvalidGlobalClusterStateFault,
InvalidParameterCombination,
InvalidParameterValue,
InvalidSubnet,
OptionGroupNotFoundFaultError,
RDSClientError,
SnapshotQuotaExceededError,
Expand Down Expand Up @@ -1573,11 +1577,128 @@ def delete(self, account_id: str, region_name: str) -> None:
backend.delete_subnet_group(self.subnet_name)


class DBProxy(BaseModel):
def __init__(
self,
db_proxy_name: str,
engine_family: str,
auth: List[Dict[str, str]],
role_arn: str,
vpc_subnet_ids: List[str],
region_name: str,
account_id: str,
vpc_security_group_ids: Optional[List[str]],
require_tls: Optional[bool] = False,
idle_client_timeout: Optional[int] = 1800,
debug_logging: Optional[bool] = False,
tags: Optional[List[Dict[str, str]]] = None,
):
self.db_proxy_name = db_proxy_name
self.engine_family = engine_family
if self.engine_family not in ["MYSQL", "POSTGRESQ", "SQLSERVER"]:
raise InvalidParameterValue("Provided EngineFamily is not valid.")
self.auth = auth
self.role_arn = role_arn
self.vpc_subnet_ids = vpc_subnet_ids
self.vpc_security_group_ids = vpc_security_group_ids
self.require_tls = require_tls
if idle_client_timeout is None:
self.idle_client_timeout = 1800
else:
if int(idle_client_timeout) < 1:
self.idle_client_timeout = 1
elif int(idle_client_timeout) > 28800:
self.idle_client_timeout = 28800
else:
self.idle_client_timeout = idle_client_timeout
self.debug_logging = debug_logging
self.created_date = iso_8601_datetime_with_milliseconds()
self.updated_date = iso_8601_datetime_with_milliseconds()
if tags is None:
self.tags = []
else:
self.tags = tags
self.region_name = region_name
self.account_id = account_id
self.db_proxy_arn = f"arn:aws:rds:{self.region_name}:{self.account_id}:db-proxy:{self.db_proxy_name}"
self.arn = self.db_proxy_arn
ec2_backend = ec2_backends[self.account_id][self.region_name]
subnets = ec2_backend.describe_subnets(subnet_ids=self.vpc_subnet_ids)
vpcs = []
for subnet in subnets:
vpcs.append(subnet.vpc_id)
if subnet.vpc_id != vpcs[0]:
raise InvalidSubnet(subnet_identifier=subnet.id)

self.vpc_id = ec2_backend.describe_subnets(subnet_ids=[self.vpc_subnet_ids[0]])[
0
].vpc_id
self.status = "availible"
self.url_identifier = "".join(
random.choice(string.ascii_lowercase + string.digits) for _ in range(12)
)
self.endpoint = f"{self.db_proxy_name}.db-proxy-{self.url_identifier}.{self.region_name}.rds.amazonaws.com"

def get_tags(self) -> List[Dict[str, str]]:
return self.tags

def add_tags(self, tags: List[Dict[str, str]]) -> List[Dict[str, str]]:
new_keys = [tag_set["Key"] for tag_set in tags]
self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys]
self.tags.extend(tags)
return self.tags

def remove_tags(self, tag_keys: List[str]) -> None:
self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys]

def to_xml(self) -> str:
template = Template(
"""
<RequireTLS>{{ dbproxy.require_tls }}</RequireTLS>
<VpcSecurityGroupIds>
{% if dbproxy.VpcSecurityGroupIds %}
{% for vpcsecuritygroupid in dbproxy.VpcSecurityGroupIds %}
<member>{{ vpcsecuritygroupid }}</member>
{% endfor %}
{% endif %}
</VpcSecurityGroupIds>
<Auth>
{% for auth in dbproxy.auth %}
<member>
<UserName>{{ auth["UserName"] }}</UserName>
<AuthScheme>{{ auth["AuthScheme"] }}</AuthScheme>
<SecretArn>{{ auth["SecretArn"] }}</SecretArn>
<IAMAuth>{{ auth["IAMAuth"] }}</IAMAuth>
<ClientPasswordAuthType>{{ auth["ClientPasswordAuthType"] }}</ClientPasswordAuthType>
</member>
{% endfor %}
</Auth>
<EngineFamily>{{ dbproxy.engine_family }}</EngineFamily>
<UpdatedDate>{{ dbproxy.updated_date }}</UpdatedDate>
<DBProxyName>{{ dbproxy.db_proxy_name }}</DBProxyName>
<IdleClientTimeout>{{ dbproxy.idle_client_timeout }}</IdleClientTimeout>
<Endpoint>{{ dbproxy.endpoint }}</Endpoint>
<CreatedDate>{{ dbproxy.created_date }}</CreatedDate>
<RoleArn>{{ dbproxy.role_arn }}</RoleArn>
<DebugLogging>{{ dbproxy.debug_logging }}</DebugLogging>
<VpcId>{{ dbproxy.vpc_id }}</VpcId>
<DBProxyArn>{{ dbproxy.db_proxy_arn }}</DBProxyArn>
<VpcSubnetIds>
{% for vpcsubnetid in dbproxy.vpc_subnet_ids %}
<member>{{ vpcsubnetid }}</member>
{% endfor %}
</VpcSubnetIds>
<Status>{{ dbproxy.status }}</Status>
"""
)
return template.render(dbproxy=self)


class RDSBackend(BaseBackend):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.arn_regex = re_compile(
r"^arn:aws:rds:.*:[0-9]*:(db|cluster|es|og|pg|ri|secgrp|snapshot|cluster-snapshot|subgrp):.*$"
r"^arn:aws:rds:.*:[0-9]*:(db|cluster|es|og|pg|ri|secgrp|snapshot|cluster-snapshot|subgrp|db-proxy):.*$"
)
self.clusters: Dict[str, Cluster] = OrderedDict()
self.global_clusters: Dict[str, GlobalCluster] = OrderedDict()
Expand All @@ -1592,6 +1713,7 @@ def __init__(self, region_name: str, account_id: str):
self.security_groups: Dict[str, SecurityGroup] = {}
self.subnet_groups: Dict[str, SubnetGroup] = {}
self._db_cluster_options: Optional[List[Dict[str, Any]]] = None
self.db_proxies: Dict[str, DBProxy] = OrderedDict()

def reset(self) -> None:
self.neptune.reset()
Expand Down Expand Up @@ -2586,6 +2708,9 @@ def list_tags_for_resource(self, arn: str) -> List[Dict[str, str]]:
elif resource_type == "subgrp": # DB subnet group
if resource_name in self.subnet_groups:
return self.subnet_groups[resource_name].get_tags()
elif resource_type == "db-proxy": # DB Proxy
if resource_name in self.db_proxies:
return self.db_proxies[resource_name].get_tags()
else:
raise RDSClientError(
"InvalidParameterValue", f"Invalid resource name: {arn}"
Expand Down Expand Up @@ -2628,6 +2753,9 @@ def remove_tags_from_resource(self, arn: str, tag_keys: List[str]) -> None:
elif resource_type == "subgrp": # DB subnet group
if resource_name in self.subnet_groups:
self.subnet_groups[resource_name].remove_tags(tag_keys)
elif resource_type == "db-proxy": # DB Proxy
if resource_name in self.db_proxies:
self.db_proxies[resource_name].remove_tags(tag_keys)
else:
raise RDSClientError(
"InvalidParameterValue", f"Invalid resource name: {arn}"
Expand Down Expand Up @@ -2669,6 +2797,9 @@ def add_tags_to_resource(self, arn: str, tags: List[Dict[str, str]]) -> List[Dic
elif resource_type == "subgrp": # DB subnet group
if resource_name in self.subnet_groups:
return self.subnet_groups[resource_name].add_tags(tags)
elif resource_type == "db-proxy": # DB Proxy
if resource_name in self.db_proxies:
return self.db_proxies[resource_name].add_tags(tags)
else:
raise RDSClientError(
"InvalidParameterValue", f"Invalid resource name: {arn}"
Expand Down Expand Up @@ -2910,6 +3041,56 @@ def modify_db_cluster_snapshot_attribute(
)
return snapshot.attributes

def create_db_proxy(
self,
db_proxy_name: str,
engine_family: str,
auth: List[Dict[str, str]],
role_arn: str,
vpc_subnet_ids: List[str],
vpc_security_group_ids: Optional[List[str]],
require_tls: Optional[bool],
idle_client_timeout: Optional[int],
debug_logging: Optional[bool],
tags: Optional[List[Dict[str, str]]],
) -> DBProxy:
self._validate_db_identifier(db_proxy_name)
if db_proxy_name in self.db_proxies:
raise DBProxyAlreadyExistsFault(db_proxy_name)
if len(self.db_proxies) >= int(os.environ.get("MOTO_RDS_PROXY_LIMIT", "100")):
raise DBProxyQuotaExceededFault()
db_proxy = DBProxy(
db_proxy_name,
engine_family,
auth,
role_arn,
vpc_subnet_ids,
self.region_name,
self.account_id,
vpc_security_group_ids,
require_tls,
idle_client_timeout,
debug_logging,
tags,
)
self.db_proxies[db_proxy_name] = db_proxy
return db_proxy

def describe_db_proxies(
self,
db_proxy_name: Optional[str],
filters: Optional[List[Dict[str, Any]]] = None,
) -> List[DBProxy]:
"""
The filters-argument is not yet supported
"""
db_proxies = list(self.db_proxies.values())
if db_proxy_name and db_proxy_name in self.db_proxies.keys():
db_proxies = [self.db_proxies[db_proxy_name]]
if db_proxy_name and db_proxy_name not in self.db_proxies.keys():
raise DBProxyNotFoundFault(db_proxy_name)
return db_proxies


class OptionGroup:
def __init__(
Expand Down
67 changes: 67 additions & 0 deletions moto/rds/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,46 @@ def modify_db_cluster_snapshot_attribute(self) -> str:
db_cluster_snapshot_identifier=db_cluster_snapshot_identifier,
)

def describe_db_proxies(self) -> str:
params = self._get_params()
db_proxy_name = params.get("DBProxyName")
# filters = params.get("Filters")
marker = params.get("Marker")
db_proxies = self.backend.describe_db_proxies(
db_proxy_name=db_proxy_name,
# filters=filters,
)
template = self.response_template(DESCRIBE_DB_PROXIES_TEMPLATE)
rendered = template.render(dbproxies=db_proxies, marker=marker)
return rendered

def create_db_proxy(self) -> str:
params = self._get_params()
db_proxy_name = params["DBProxyName"]
engine_family = params["EngineFamily"]
auth = params["Auth"]
role_arn = params["RoleArn"]
vpc_subnet_ids = params["VpcSubnetIds"]
vpc_security_group_ids = params.get("VpcSecurityGroupIds")
require_tls = params.get("RequireTLS")
idle_client_timeout = params.get("IdleClientTimeout")
debug_logging = params.get("DebugLogging")
tags = self.unpack_list_params("Tags", "Tag")
db_proxy = self.backend.create_db_proxy(
db_proxy_name=db_proxy_name,
engine_family=engine_family,
auth=auth,
role_arn=role_arn,
vpc_subnet_ids=vpc_subnet_ids,
vpc_security_group_ids=vpc_security_group_ids,
require_tls=require_tls,
idle_client_timeout=idle_client_timeout,
debug_logging=debug_logging,
tags=tags,
)
template = self.response_template(CREATE_DB_PROXY_TEMPLATE)
return template.render(dbproxy=db_proxy)


CREATE_DATABASE_TEMPLATE = """<CreateDBInstanceResponse xmlns="http://rds.amazonaws.com/doc/2014-09-01/">
<CreateDBInstanceResult>
Expand Down Expand Up @@ -1630,3 +1670,30 @@ def modify_db_cluster_snapshot_attribute(self) -> str:
<RequestId>1549581b-12b7-11e3-895e-1334a</RequestId>
</ResponseMetadata>
</DescribeDBClusterSnapshotAttributesResponse>"""

CREATE_DB_PROXY_TEMPLATE = """<CreateDBProxyResponse xmlns="http://rds.amazonaws.com/doc/2014-10-31/">
<CreateDBProxyResult>
<DBProxy>
{{ dbproxy.to_xml() }}
</DBProxy>
</CreateDBProxyResult>
<ResponseMetadata>
<RequestId>1549581b-12b7-11e3-895e-1334aEXAMPLE</RequestId>
</ResponseMetadata>
</CreateDBProxyResponse>"""

DESCRIBE_DB_PROXIES_TEMPLATE = """<DescribeDBProxiesResponse xmlns="http://rds.amazonaws.com/doc/2014-10-31/">
<DescribeDBProxiesResult>
<DBProxies>
{% for dbproxy in dbproxies %}
<member>
{{ dbproxy.to_xml() }}
</member>
{% endfor %}
</DBProxies>
</DescribeDBProxiesResult>
<ResponseMetadata>
<RequestId>1549581b-12b7-11e3-895e-1334a</RequestId>
</ResponseMetadata>
</DescribeDBProxiesResponse>
"""
1 change: 1 addition & 0 deletions moto/resourcegroupstaggingapi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ def format_tag_keys(
"rds:db": self.rds_backend.databases,
"rds:snapshot": self.rds_backend.database_snapshots,
"rds:cluster-snapshot": self.rds_backend.cluster_snapshots,
"rds:db-proxy": self.rds_backend.db_proxies,
}
for resource_type, resource_source in resource_map.items():
if (
Expand Down

0 comments on commit 59248f3

Please sign in to comment.