support external postgresql, mysql and mariadb databases, resolves #72

This commit is contained in:
meeb
2021-04-04 23:01:15 +10:00
parent 3ec4f7c525
commit 20df9f4044
10 changed files with 298 additions and 19 deletions

View File

@@ -20,3 +20,10 @@ class DownloadFailedException(Exception):
exist.
'''
pass
class DatabaseConnectionError(Exception):
'''
Raised when parsing or initially connecting to a database.
'''
pass

View File

@@ -2,6 +2,8 @@ import os.path
from django.conf import settings
from django.test import TestCase, Client
from .testutils import prevent_request_warnings
from .utils import parse_database_connection_string
from .errors import DatabaseConnectionError
class ErrorPageTestCase(TestCase):
@@ -61,3 +63,40 @@ class CommonStaticTestCase(TestCase):
favicon_real_path = os.path.join(os.sep.join(root_parts),
os.sep.join(url_parts))
self.assertTrue(os.path.exists(favicon_real_path))
class DatabaseConnectionTestCase(TestCase):
def test_parse_database_connection_string(self):
database_dict = parse_database_connection_string(
'postgresql://tubesync:password@localhost:5432/tubesync')
database_dict = parse_database_connection_string(
'mysql://tubesync:password@localhost:3306/tubesync')
# Invalid driver
with self.assertRaises(DatabaseConnectionError):
parse_database_connection_string(
'test://tubesync:password@localhost:5432/tubesync')
# No username
with self.assertRaises(DatabaseConnectionError):
parse_database_connection_string(
'postgresql://password@localhost:5432/tubesync')
# No database name
with self.assertRaises(DatabaseConnectionError):
parse_database_connection_string(
'postgresql://tubesync:password@5432')
# Invalid port
with self.assertRaises(DatabaseConnectionError):
parse_database_connection_string(
'postgresql://tubesync:password@localhost:test/tubesync')
# Invalid port
with self.assertRaises(DatabaseConnectionError):
parse_database_connection_string(
'postgresql://tubesync:password@localhost:65537/tubesync')
# Invalid username or password
with self.assertRaises(DatabaseConnectionError):
parse_database_connection_string(
'postgresql://tubesync:password:test@localhost:5432/tubesync')
# Invalid database name
with self.assertRaises(DatabaseConnectionError):
parse_database_connection_string(
'postgresql://tubesync:password@localhost:5432/tubesync/test')

View File

@@ -1,4 +1,85 @@
from urllib.parse import urlunsplit, urlencode
from urllib.parse import urlunsplit, urlencode, urlparse
from .errors import DatabaseConnectionError
def parse_database_connection_string(database_connection_string):
'''
Parses a connection string in a URL style format, such as:
postgresql://tubesync:password@localhost:5432/tubesync
mysql://someuser:somepassword@localhost:3306/tubesync
into a Django-compatible settings.DATABASES dict format.
'''
valid_drivers = ('postgresql', 'mysql')
default_ports = {
'postgresql': 5432,
'mysql': 3306,
}
django_backends = {
'postgresql': 'django.db.backends.postgresql',
'mysql': 'django.db.backends.mysql',
}
try:
parts = urlparse(str(database_connection_string))
except Exception as e:
raise DatabaseConnectionError(f'Failed to parse "{database_connection_string}" '
f'as a database connection string: {e}') from e
driver = parts.scheme
user_pass_host_port = parts.netloc
database = parts.path
if driver not in valid_drivers:
raise DatabaseConnectionError(f'Database connection string '
f'"{database_connection_string}" specified an '
f'invalid driver, must be one of {valid_drivers}')
django_driver = django_backends.get(driver)
host_parts = user_pass_host_port.split('@')
if len(host_parts) != 2:
raise DatabaseConnectionError(f'Database connection string netloc must be in '
f'the format of user:pass@host')
user_pass, host_port = host_parts
user_pass_parts = user_pass.split(':')
if len(user_pass_parts) != 2:
raise DatabaseConnectionError(f'Database connection string netloc must be in '
f'the format of user:pass@host')
username, password = user_pass_parts
host_port_parts = host_port.split(':')
if len(host_port_parts) == 1:
# No port number, assign a default port
hostname = host_port_parts[0]
port = default_ports.get(driver)
elif len(host_port_parts) == 2:
# Host name and port number
hostname, port = host_port_parts
try:
port = int(port)
except (ValueError, TypeError) as e:
raise DatabaseConnectionError(f'Database connection string contained an '
f'invalid port, ports must be integers: '
f'{e}') from e
if not 0 < port < 63336:
raise DatabaseConnectionError(f'Database connection string contained an '
f'invalid port, ports must be between 1 and '
f'65535, got {port}')
else:
# Malformed
raise DatabaseConnectionError(f'Database connection host must be a hostname or '
f'a hostname:port combination')
if database.startswith('/'):
database = database[1:]
if not database:
raise DatabaseConnectionError(f'Database connection string path must be a '
f'string in the format of /databasename')
if '/' in database:
raise DatabaseConnectionError(f'Database connection string path can only '
f'contain a single string name, got: {database}')
return {
'DRIVER': driver,
'ENGINE': django_driver,
'NAME': database,
'USER': username,
'PASSWORD': password,
'HOST': hostname,
'PORT': port,
}
def get_client_ip(request):

View File

@@ -123,6 +123,10 @@
<td class="hide-on-small-only">Downloads directory</td>
<td><span class="hide-on-med-and-up">Downloads directory<br></span><strong>{{ downloads_dir }}</strong></td>
</tr>
<tr title="Database connection used by TubeSync">
<td class="hide-on-small-only">Database</td>
<td><span class="hide-on-med-and-up">Database<br></span><strong>{{ database_connection }}</strong></td>
</tr>
</table>
</div>
</div>

View File

@@ -78,6 +78,7 @@ class DashboardView(TemplateView):
# Config and download locations
data['config_dir'] = str(settings.CONFIG_BASE_DIR)
data['downloads_dir'] = str(settings.DOWNLOAD_ROOT)
data['database_connection'] = settings.DATABASE_CONNECTION_STR
return data

View File

@@ -1,5 +1,7 @@
import os
from pathlib import Path
from common.logger import log
from common.utils import parse_database_connection_string
BASE_DIR = Path(__file__).resolve().parent.parent
@@ -21,12 +23,31 @@ FORCE_SCRIPT_NAME = os.getenv('DJANGO_FORCE_SCRIPT_NAME', None)
TIME_ZONE = os.getenv('TZ', 'UTC')
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.sqlite3',
'NAME': CONFIG_BASE_DIR / 'db.sqlite3',
database_dict = {}
database_connection_env = os.getenv('DATABASE_CONNECTION', '')
if database_connection_env:
database_dict = parse_database_connection_string(database_connection_env)
if database_dict:
log.info(f'Using database connection: {database_dict["ENGINE"]}://'
f'{database_dict["USER"]}:[hidden]@{database_dict["HOST"]}:'
f'{database_dict["PORT"]}/{database_dict["NAME"]}')
DATABASES = {
'default': database_dict,
}
}
DATABASE_CONNECTION_STR = (f'{database_dict["DRIVER"]} at "{database_dict["HOST"]}:'
f'{database_dict["PORT"]}" database '
f'"{database_dict["NAME"]}"')
else:
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.sqlite3',
'NAME': CONFIG_BASE_DIR / 'db.sqlite3',
}
}
DATABASE_CONNECTION_STR = f'sqlite at "{DATABASES["default"]["NAME"]}"'
DEFAULT_THREADS = 1
MAX_BACKGROUND_TASK_ASYNC_THREADS = 8