Skip to content

CredentialsProvider class added to support password rotation #2261

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Nov 10, 2022
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
86da7ff
A CredentialsProvider class has been added to allow the user to add h…
barshaul Jul 5, 2022
5dfddde
Moved CredentialsProvider to a separate file, added type hints
barshaul Jul 18, 2022
243b244
Changed username and password to properties
barshaul Aug 9, 2022
af8e560
Added: StaticCredentialProvider, examples, tests
barshaul Aug 23, 2022
2261cb0
Changed private members' prefix to __
barshaul Aug 30, 2022
ddfe1ea
fixed linters
barshaul Sep 1, 2022
b481067
fixed auth test
barshaul Sep 4, 2022
686d172
fixed credential test
barshaul Sep 4, 2022
d1d10af
Raise an error if username or password are passed along with credenti…
barshaul Sep 20, 2022
9de8d21
fixing linters
barshaul Sep 20, 2022
def996b
fixing test
barshaul Sep 21, 2022
29c8006
Changed dundered to single per side underscore
barshaul Oct 2, 2022
6b8cf1f
Changed Connection class members username and password to properties …
barshaul Oct 2, 2022
c37e0f1
Reverting last commit and adding backward compatibility to 'username'…
barshaul Oct 3, 2022
abe6137
Refactored CredentialProvider class
barshaul Nov 2, 2022
6303243
Fixing tuple type to Tuple
barshaul Nov 2, 2022
057ed82
Fixing optional string members in UsernamePasswordCredentialProvider
barshaul Nov 2, 2022
ba91b0f
Fixed credential test
barshaul Nov 2, 2022
6223901
Added credential provider support to AsyncRedis
barshaul Nov 9, 2022
fac8333
Merge branch 'master' into creds_provider
dvora-h Nov 10, 2022
b951e19
linters
dvora-h Nov 10, 2022
72c366d
linters
dvora-h Nov 10, 2022
4b35cb2
linters
dvora-h Nov 10, 2022
4c82551
linters - black
dvora-h Nov 10, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
* ClusterPipeline Doesn't Handle ConnectionError for Dead Hosts (#2225)
* Remove compatibility code for old versions of Hiredis, drop Packaging dependency
* The `deprecated` library is no longer a dependency
* Added CredentialsProvider class to support password rotation

* 4.1.3 (Feb 8, 2022)
* Fix flushdb and flushall (#1926)
Expand Down
193 changes: 192 additions & 1 deletion docs/examples/connection_examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,197 @@
"user_connection.ping()"
]
},
{
"cell_type": "markdown",
"source": [
"## Connecting to a redis instance with username and password credential provider"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"import redis\n",
"\n",
"creds_provider = redis.UsernamePasswordCredentialProvider(\"username\", \"password\")\n",
"user_connection = redis.Redis(host=\"localhost\", port=6379, credential_provider=creds_provider)\n",
"user_connection.ping()"
],
"metadata": {}
}
},
{
"cell_type": "markdown",
"source": [
"## Connecting to a redis instance with standard credential provider"
],
"metadata": {}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"from typing import Tuple\n",
"import redis\n",
"\n",
"creds_map = {\"user_1\": \"pass_1\",\n",
" \"user_2\": \"pass_2\"}\n",
"\n",
"class UserMapCredentialProvider(redis.CredentialProvider):\n",
" def __init__(self, username: str):\n",
" self.username = username\n",
"\n",
" def get_credentials(self) -> Tuple[str, str]:\n",
" return self.username, creds_map.get(self.username)\n",
"\n",
"# Create a default connection to set the ACL user\n",
"default_connection = redis.Redis(host=\"localhost\", port=6379)\n",
"default_connection.acl_setuser(\n",
" \"user_1\",\n",
" enabled=True,\n",
" passwords=[\"+\" + \"pass_1\"],\n",
" keys=\"~*\",\n",
" commands=[\"+ping\", \"+command\", \"+info\", \"+select\", \"+flushdb\"],\n",
")\n",
"\n",
"# Create a UserMapCredentialProvider instance for user_1\n",
"creds_provider = UserMapCredentialProvider(\"user_1\")\n",
"# Initiate user connection with the credential provider\n",
"user_connection = redis.Redis(host=\"localhost\", port=6379,\n",
" credential_provider=creds_provider)\n",
"user_connection.ping()"
],
"metadata": {}
}
},
{
"cell_type": "markdown",
"source": [
"## Connecting to a redis instance first with an initial credential set and then calling the credential provider"
],
"metadata": {}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"from typing import Union\n",
"import redis\n",
"\n",
"class InitCredsSetCredentialProvider(redis.CredentialProvider):\n",
" def __init__(self, username, password):\n",
" self.username = username\n",
" self.password = password\n",
" self.call_supplier = False\n",
"\n",
" def call_external_supplier(self) -> Union[Tuple[str], Tuple[str, str]]:\n",
" # Call to an external credential supplier\n",
" raise NotImplementedError\n",
"\n",
" def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:\n",
" if self.call_supplier:\n",
" return self.call_external_supplier()\n",
" # Use the init set only for the first time\n",
" self.call_supplier = True\n",
" return self.username, self.password\n",
"\n",
"cred_provider = InitCredsSetCredentialProvider(username=\"init_user\", password=\"init_pass\")"
],
"metadata": {}
}
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## Connecting to a redis instance with AWS Secrets Manager credential provider."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"import redis\n",
"import boto3\n",
"import json\n",
"import cachetools.func\n",
"\n",
"sm_client = boto3.client('secretsmanager')\n",
" \n",
"def sm_auth_provider(self, secret_id, version_id=None, version_stage='AWSCURRENT'):\n",
" @cachetools.func.ttl_cache(maxsize=128, ttl=24 * 60 * 60) #24h\n",
" def get_sm_user_credentials(secret_id, version_id, version_stage):\n",
" secret = sm_client.get_secret_value(secret_id, version_id)\n",
" return json.loads(secret['SecretString'])\n",
" creds = get_sm_user_credentials(secret_id, version_id, version_stage)\n",
" return creds['username'], creds['password']\n",
"\n",
"secret_id = \"EXAMPLE1-90ab-cdef-fedc-ba987SECRET1\"\n",
"creds_provider = redis.CredentialProvider(supplier=sm_auth_provider, secret_id=secret_id)\n",
"user_connection = redis.Redis(host=\"localhost\", port=6379, credential_provider=creds_provider)\n",
"user_connection.ping()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Connecting to a redis instance with ElastiCache IAM credential provider."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import redis\n",
"import boto3\n",
"import cachetools.func\n",
"\n",
"ec_client = boto3.client('elasticache')\n",
"\n",
"def iam_auth_provider(self, user, endpoint, port=6379, region=\"us-east-1\"):\n",
" @cachetools.func.ttl_cache(maxsize=128, ttl=15 * 60) # 15m\n",
" def get_iam_auth_token(user, endpoint, port, region):\n",
" return ec_client.generate_iam_auth_token(user, endpoint, port, region)\n",
" iam_auth_token = get_iam_auth_token(endpoint, port, user, region)\n",
" return iam_auth_token\n",
"\n",
"username = \"barshaul\"\n",
"endpoint = \"test-001.use1.cache.amazonaws.com\"\n",
"creds_provider = redis.CredentialProvider(supplier=iam_auth_provider, user=username,\n",
" endpoint=endpoint)\n",
"user_connection = redis.Redis(host=endpoint, port=6379, credential_provider=creds_provider)\n",
"user_connection.ping()"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -176,4 +367,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
3 changes: 3 additions & 0 deletions redis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
SSLConnection,
UnixDomainSocketConnection,
)
from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
from redis.exceptions import (
AuthenticationError,
AuthenticationWrongNumberOfArgsError,
Expand Down Expand Up @@ -62,6 +63,7 @@ def int_or_str(value):
"Connection",
"ConnectionError",
"ConnectionPool",
"CredentialProvider",
"DataError",
"from_url",
"InvalidResponse",
Expand All @@ -76,6 +78,7 @@ def int_or_str(value):
"SentinelManagedConnection",
"SentinelManagedSSLConnection",
"SSLConnection",
"UsernamePasswordCredentialProvider",
"StrictRedis",
"TimeoutError",
"UnixDomainSocketConnection",
Expand Down
4 changes: 4 additions & 0 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import time
import warnings
from itertools import chain
from typing import Optional

from redis.commands import (
CoreCommands,
Expand All @@ -13,6 +14,7 @@
list_or_args,
)
from redis.connection import ConnectionPool, SSLConnection, UnixDomainSocketConnection
from redis.credentials import CredentialProvider
from redis.exceptions import (
ConnectionError,
ExecAbortError,
Expand Down Expand Up @@ -938,6 +940,7 @@ def __init__(
username=None,
retry=None,
redis_connect_func=None,
credential_provider: Optional[CredentialProvider] = None,
):
"""
Initialize a new Redis client.
Expand Down Expand Up @@ -985,6 +988,7 @@ def __init__(
"health_check_interval": health_check_interval,
"client_name": client_name,
"redis_connect_func": redis_connect_func,
"credential_provider": credential_provider,
}
# based on input, setup appropriate connection args
if unix_socket_path is not None:
Expand Down
1 change: 1 addition & 0 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def parse_cluster_shards(resp, **options):
"connection_class",
"connection_pool",
"client_name",
"credential_provider",
"db",
"decode_responses",
"encoding",
Expand Down
40 changes: 31 additions & 9 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
from itertools import chain
from queue import Empty, Full, LifoQueue
from time import time
from typing import Optional
from urllib.parse import parse_qs, unquote, urlparse

from redis.backoff import NoBackoff
from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
from redis.exceptions import (
AuthenticationError,
AuthenticationWrongNumberOfArgsError,
Expand Down Expand Up @@ -502,6 +504,7 @@ def __init__(
username=None,
retry=None,
redis_connect_func=None,
credential_provider: Optional[CredentialProvider] = None,
):
"""
Initialize a new Connection.
Expand All @@ -514,9 +517,18 @@ def __init__(
self.host = host
self.port = int(port)
self.db = db
self.username = username
self.client_name = client_name
if (username or password) and credential_provider is not None:
raise DataError(
"'username' and 'password' cannot be passed along with 'credential_"
"provider'. Please provide only one of the following arguments: \n"
"1. 'password' and (optional) 'username'\n"
"2. 'credential_provider'"
)

self.credential_provider = credential_provider
self.password = password
self.username = username
self.socket_timeout = socket_timeout
self.socket_connect_timeout = socket_connect_timeout or socket_timeout
self.socket_keepalive = socket_keepalive
Expand Down Expand Up @@ -675,12 +687,13 @@ def on_connect(self):
"Initialize the connection, authenticate and select a database"
self._parser.on_connect(self)

# if username and/or password are set, authenticate
if self.username or self.password:
if self.username:
auth_args = (self.username, self.password or "")
else:
auth_args = (self.password,)
# if credential provider or username and/or password are set, authenticate
if self.credential_provider or (self.username or self.password):
cred_provider = (
self.credential_provider
or UsernamePasswordCredentialProvider(self.username, self.password)
)
auth_args = cred_provider.get_credentials()
# avoid checking health here -- PING will fail if we try
# to check the health prior to the AUTH
self.send_command("AUTH", *auth_args, check_health=False)
Expand All @@ -692,7 +705,7 @@ def on_connect(self):
# server seems to be < 6.0.0 which expects a single password
# arg. retry auth with just the password.
# https://github.com/andymccurdy/redis-py/issues/1274
self.send_command("AUTH", self.password, check_health=False)
self.send_command("AUTH", auth_args[-1], check_health=False)
auth_response = self.read_response()

if str_if_bytes(auth_response) != "OK":
Expand Down Expand Up @@ -1050,6 +1063,7 @@ def __init__(
client_name=None,
retry=None,
redis_connect_func=None,
credential_provider: Optional[CredentialProvider] = None,
):
"""
Initialize a new UnixDomainSocketConnection.
Expand All @@ -1061,9 +1075,17 @@ def __init__(
self.pid = os.getpid()
self.path = path
self.db = db
self.username = username
self.client_name = client_name
if (username or password) and credential_provider is not None:
raise DataError(
"'username' and 'password' cannot be passed along with 'credential_"
"provider'. Please provide only one of the following arguments: \n"
"1. 'password' and (optional) 'username'\n"
"2. 'credential_provider'"
)
self.credential_provider = credential_provider
self.password = password
self.username = username
self.socket_timeout = socket_timeout
self.retry_on_timeout = retry_on_timeout
if retry_on_error is SENTINEL:
Expand Down
26 changes: 26 additions & 0 deletions redis/credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Optional, Tuple, Union


class CredentialProvider:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why this needs to be a class at all – couldn't a provider just be a function get_credentials(*, username, password)...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Different credential providers will need different arguments to be passed, so we can save within the credential provider object the required args/kwargs.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right: a credential provider could indeed just be a get_credentials() function with no knowledge of any arguments. 😉

def get_very_specific_credential_provider(some_auth_url):
    def provider():
        # ... do some API call..?
        return ("foo", "bar")
    return provider

c = Client(
    credential_provider=get_very_specific_credential_provider("https://secret..."),
)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, that would do the job. :)
@chayim WDYT?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no reason that during the instantiation of all Redis connection objects a function can't be passed in, as opposed to the class being discussed. However, I think the current class implementation is more readable for the community that uses this. I'm partial for erring towards readability, and the clear separation it provides.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@chayim @barshaul Well, this class looks like it's become quite complicated and hard-to-follow.

It has an username/password that may not be used at all, args and kwargs that may not be used at all.

If you want a class, then why not just

class CredentialProvider:
    def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:
        raise NotImplementedError("get_credentials must be implemented")


class UsernamePasswordCredentialProvider(CredentialProvider):
    def __init__(self, username: Optional[str]=None, password: Optional[str]=None):
        self.username = username
        self.password = password

    def get_credentials(self):
        if self.username:
            return (self.username, self._password)
        return (self.password,)

and leave the rest (e.g. implementing a specific provider with a specific get_credentials) to the users?

As an aside – should get_credentials maybe also support being async? :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the changes now done to Connection, I really do believe that the implementation I suggest above would be the better fit, and the current CredentialProvider is way too generic. I sincerely ask you to revise it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_credentials should definitely support an async all, to keep this aligned.

On the class inputs side - I don't think it's unreasonable to pass in an optional CredentialProvider on a per redis-connection type. I see classes as more useful given state needs, depending on the provider. Maybe instead, CredentialProvider shouldn't require any specified arguments, instead be a descendent of an ABC that accepts **kwargs. This could address the same need, and reduce the complexity, but yes - be perhaps generic enough (though perhaps something shared).

@akx WDYT needlessly generic? Having a way to do this for systems that are completely different makes it problematic.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@chayim The "needlessly generic" comment was outdated by abe6137, which removed the "username + password + supplier + args + kwargs" mode.

"""
Credentials Provider.
"""

def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:
raise NotImplementedError("get_credentials must be implemented")


class UsernamePasswordCredentialProvider(CredentialProvider):
"""
Simple implementation of CredentialProvider that just wraps static
username and password.
"""

def __init__(self, username: Optional[str] = None, password: Optional[str] = None):
self.username = username or ""
self.password = password or ""

def get_credentials(self):
if self.username:
return self.username, self.password
return (self.password,)
Loading