tubesync/tubesync/common/utils.py

132 lines
4.9 KiB
Python

import string
from datetime import datetime
from urllib.parse import urlunsplit, urlencode, urlparse
from yt_dlp.utils import LazyList
from .errors import DatabaseConnectionError
def parse_database_connection_string(database_connection_string):
    '''
        Parses a connection string in a URL style format, such as:
            postgresql://tubesync:password@localhost:5432/tubesync
            mysql://someuser:somepassword@localhost:3306/tubesync
        into a Django-compatible settings.DATABASES dict format.

        The netloc must be exactly "user:pass@host" or "user:pass@host:port".
        The parts are split on literal '@' and ':' characters, so usernames
        and passwords containing either character are rejected. When no port
        is given the default port for the driver is used.

        Returns a dict with the keys DRIVER, ENGINE, NAME, USER, PASSWORD,
        HOST, PORT, CONN_MAX_AGE and OPTIONS.

        Raises DatabaseConnectionError if the string cannot be parsed, names
        an unsupported driver, has a malformed netloc, an invalid port or an
        invalid database name.
    '''
    valid_drivers = ('postgresql', 'mysql')
    default_ports = {
        'postgresql': 5432,
        'mysql': 3306,
    }
    django_backends = {
        'postgresql': 'django.db.backends.postgresql',
        'mysql': 'django.db.backends.mysql',
    }
    backend_options = {
        'postgresql': {},
        'mysql': {
            'charset': 'utf8mb4',
        }
    }
    # Inclusive bounds of a valid TCP port number
    min_port, max_port = 1, 65535
    try:
        parts = urlparse(str(database_connection_string))
    except Exception as e:
        raise DatabaseConnectionError(f'Failed to parse "{database_connection_string}" '
                                      f'as a database connection string: {e}') from e
    driver = parts.scheme
    user_pass_host_port = parts.netloc
    database = parts.path
    if driver not in valid_drivers:
        raise DatabaseConnectionError(f'Database connection string '
                                      f'"{database_connection_string}" specified an '
                                      f'invalid driver, must be one of {valid_drivers}')
    django_driver = django_backends.get(driver)
    host_parts = user_pass_host_port.split('@')
    if len(host_parts) != 2:
        raise DatabaseConnectionError(f'Database connection string netloc must be in '
                                      f'the format of user:pass@host')
    user_pass, host_port = host_parts
    user_pass_parts = user_pass.split(':')
    if len(user_pass_parts) != 2:
        raise DatabaseConnectionError(f'Database connection string netloc must be in '
                                      f'the format of user:pass@host')
    username, password = user_pass_parts
    host_port_parts = host_port.split(':')
    if len(host_port_parts) == 1:
        # No port number, assign a default port
        hostname = host_port_parts[0]
        port = default_ports.get(driver)
    elif len(host_port_parts) == 2:
        # Host name and port number
        hostname, port = host_port_parts
        try:
            port = int(port)
        except (ValueError, TypeError) as e:
            raise DatabaseConnectionError(f'Database connection string contained an '
                                          f'invalid port, ports must be integers: '
                                          f'{e}') from e
        # The upper bound was previously mistyped as 63336, which rejected
        # the valid ports 63336-65535 despite the error message below
        if not min_port <= port <= max_port:
            raise DatabaseConnectionError(f'Database connection string contained an '
                                          f'invalid port, ports must be between 1 and '
                                          f'65535, got {port}')
    else:
        # Malformed
        raise DatabaseConnectionError(f'Database connection host must be a hostname or '
                                      f'a hostname:port combination')
    # Only a single leading slash is removed, any remaining slash is an error
    if database.startswith('/'):
        database = database[1:]
    if not database:
        raise DatabaseConnectionError(f'Database connection string path must be a '
                                      f'string in the format of /databasename')
    if '/' in database:
        raise DatabaseConnectionError(f'Database connection string path can only '
                                      f'contain a single string name, got: {database}')
    return {
        'DRIVER': driver,
        'ENGINE': django_driver,
        'NAME': database,
        'USER': username,
        'PASSWORD': password,
        'HOST': hostname,
        'PORT': port,
        'CONN_MAX_AGE': 300,
        'OPTIONS': backend_options.get(driver),
    }
def get_client_ip(request):
    '''
        Returns the client IP address for a request. The first comma
        separated entry of the HTTP_X_FORWARDED_FOR value in request.META
        is used when that value is present and non-empty, otherwise the
        REMOTE_ADDR value (or None if that is also missing) is returned.
        The entry is returned as-is, no whitespace is stripped.
    '''
    meta = request.META
    forwarded_chain = meta.get('HTTP_X_FORWARDED_FOR')
    if not forwarded_chain:
        # No forwarding header, fall back to the directly connected peer
        return meta.get('REMOTE_ADDR')
    first_hop, _, _ = forwarded_chain.partition(',')
    return first_hop
def append_uri_params(uri, params):
    '''
        Returns uri (coerced to a str) with params URL-encoded and attached
        as its query string. The uri is passed to urlunsplit() as the path
        component with an empty scheme, netloc and fragment.
    '''
    scheme = netloc = fragment = ''
    path = str(uri)
    query = urlencode(params)
    return urlunsplit((scheme, netloc, path, query, fragment))
def clean_filename(filename):
    '''
        Returns filename with unsafe characters removed so it can be used
        as a file name:
            - the characters < > \\ / : * ? " | % are removed
            - ASCII whitespace (string.whitespace) becomes a single space
            - remaining ASCII control characters (0x00 - 0x1f) are removed
            - leading and trailing whitespace is stripped

        Raises ValueError if filename is not a str.
    '''
    if not isinstance(filename, str):
        raise ValueError(f'filename must be a str, got {type(filename)}')
    # The backslash is escaped explicitly, '\/' is an invalid escape sequence
    to_scrub = '<>\\/:*?"|%'
    kept = []
    for char in filename:
        if char in to_scrub:
            continue
        if char in string.whitespace:
            # Tabs, newlines and so on are normalised to a plain space
            char = ' '
        # Printable characters start at 0x20, everything below is a control
        # character (previously 0x1f slipped through a "> 30" comparison)
        if ord(char) >= 32:
            kept.append(char)
    return ''.join(kept).strip()
def json_serial(obj):
    '''
        Converts obj into a JSON-friendly value: datetime instances become
        their ISO 8601 string via isoformat() and LazyList instances become
        plain lists. Any other type raises a TypeError.
    '''
    if isinstance(obj, datetime):
        serialised = obj.isoformat()
    elif isinstance(obj, LazyList):
        serialised = list(obj)
    else:
        raise TypeError(f'Type {type(obj)} is not json_serial()-able')
    return serialised