diff --git a/tubesync/common/tests.py b/tubesync/common/tests.py index fc93e3e..1bdd16a 100644 --- a/tubesync/common/tests.py +++ b/tubesync/common/tests.py @@ -2,7 +2,7 @@ import os.path from django.conf import settings from django.test import TestCase, Client from .testutils import prevent_request_warnings -from .utils import parse_database_connection_string +from .utils import parse_database_connection_string, clean_filename from .errors import DatabaseConnectionError @@ -65,7 +65,7 @@ class CommonStaticTestCase(TestCase): self.assertTrue(os.path.exists(favicon_real_path)) -class DatabaseConnectionTestCase(TestCase): +class UtilsTestCase(TestCase): def test_parse_database_connection_string(self): database_dict = parse_database_connection_string( @@ -126,3 +126,12 @@ class DatabaseConnectionTestCase(TestCase): with self.assertRaises(DatabaseConnectionError): parse_database_connection_string( 'postgresql://tubesync:password@localhost:5432/tubesync/test') + + def test_clean_filename(self): + self.assertEqual(clean_filename('a'), 'a') + self.assertEqual(clean_filename('a\t'), 'a') + self.assertEqual(clean_filename('a\n'), 'a') + self.assertEqual(clean_filename('a a'), 'a a') + self.assertEqual(clean_filename('a a'), 'a a') + self.assertEqual(clean_filename('a\t\t\ta'), 'a a') + self.assertEqual(clean_filename('a\t\t\ta\t\t\t'), 'a a') diff --git a/tubesync/common/utils.py b/tubesync/common/utils.py index 0b2dee8..4cb9f93 100644 --- a/tubesync/common/utils.py +++ b/tubesync/common/utils.py @@ -1,3 +1,4 @@ +import string from datetime import datetime from urllib.parse import urlunsplit, urlencode, urlparse from yt_dlp.utils import LazyList @@ -113,8 +114,13 @@ def clean_filename(filename): to_scrub = '<>\/:*?"|%' for char in to_scrub: filename = filename.replace(char, '') - filename = ''.join([c for c in filename if ord(c) > 30]) - return ' '.join(filename.split()) + clean_filename = '' + for c in filename: + if c in string.whitespace: + c = ' ' + if ord(c) > 30: + clean_filename += c + return clean_filename.strip() def json_serial(obj):