diff --git a/changelog.d/18527.feature b/changelog.d/18527.feature new file mode 100644 index 00000000000..3394f6786f4 --- /dev/null +++ b/changelog.d/18527.feature @@ -0,0 +1 @@ +Add ability to limit amount uploaded by a user in a given time period. diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index c014de794d2..401c1ff7db9 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -2058,6 +2058,23 @@ Example configuration: max_upload_size: 60M ``` --- +### `media_upload_limits` + +*(array)* A list of media upload limits defining how much data a given user can upload in a given time period. + +An empty list means no limits are applied. + +Defaults to `[]`. + +Example configuration: +```yaml +media_upload_limits: +- time_period: 1h + max_size: 100M +- time_period: 1w + max_size: 500M +``` +--- ### `max_image_pixels` *(byte size)* Maximum number of pixels that will be thumbnailed. Defaults to `"32M"`. diff --git a/schema/synapse-config.schema.yaml b/schema/synapse-config.schema.yaml index 5ebe80f51cb..84f8a664738 100644 --- a/schema/synapse-config.schema.yaml +++ b/schema/synapse-config.schema.yaml @@ -2300,6 +2300,30 @@ properties: default: 50M examples: - 60M + media_upload_limits: + type: array + description: >- + A list of media upload limits defining how much data a given user can + upload in a given time period. + + + An empty list means no limits are applied. + default: [] + items: + time_period: + type: "#/$defs/duration" + description: >- + The time period over which the limit applies. Required. + max_size: + type: "#/$defs/bytes" + description: >- + Amount of data that can be uploaded in the time period by the user. + Required. + examples: + - - time_period: 1h + max_size: 100M + - time_period: 1w + max_size: 500M max_image_pixels: $ref: "#/$defs/bytes" description: Maximum number of pixels that will be thumbnailed. diff --git a/synapse/config/repository.py b/synapse/config/repository.py index fc5a90c85a6..e6a5064c166 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -119,6 +119,15 @@ def parse_thumbnail_requirements( } +@attr.s(auto_attribs=True, slots=True, frozen=True) +class MediaUploadLimit: + """A limit on the amount of data a user can upload in a given time + period.""" + + max_bytes: int + time_period_ms: int + + class ContentRepositoryConfig(Config): section = "media" @@ -274,6 +283,13 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: self.enable_authenticated_media = config.get("enable_authenticated_media", True) + self.media_upload_limits: List[MediaUploadLimit] = [] + for limit_config in config.get("media_upload_limits", []): + time_period_ms = self.parse_duration(limit_config["time_period"]) + max_bytes = self.parse_size(limit_config["max_size"]) + + self.media_upload_limits.append(MediaUploadLimit(max_bytes, time_period_ms)) + def generate_config_section(self, data_dir_path: str, **kwargs: Any) -> str: assert data_dir_path is not None media_store = os.path.join(data_dir_path, "media_store") diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 07827cf95bd..10e3d37d4e2 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -824,7 +824,7 @@ def is_allowed_mime_type(content_type: str) -> bool: return True # store it in media repository - avatar_mxc_url = await self._media_repo.create_content( + avatar_mxc_url = await self._media_repo.create_or_update_content( media_type=headers[b"Content-Type"][0].decode("utf-8"), upload_name=upload_name, content=picture, diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 18c5a8ecec4..8b8af050613 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -177,6 +177,13 @@ def __init__(self, hs: "HomeServer"): else: self.url_previewer = None + # We get the media upload limits and sort them in descending order of + # time period, so that we can apply some optimizations. + self.media_upload_limits = hs.config.media.media_upload_limits + self.media_upload_limits.sort( + key=lambda limit: limit.time_period_ms, reverse=True + ) + def _start_update_recently_accessed(self) -> Deferred: return run_as_background_process( "update_recently_accessed_media", self._update_recently_accessed @@ -285,63 +292,16 @@ async def verify_can_upload(self, media_id: str, auth_user: UserID) -> None: raise NotFoundError("Media ID has expired") @trace - async def update_content( - self, - media_id: str, - media_type: str, - upload_name: Optional[str], - content: IO, - content_length: int, - auth_user: UserID, - ) -> None: - """Update the content of the given media ID. - - Args: - media_id: The media ID to replace. - media_type: The content type of the file. - upload_name: The name of the file, if provided. - content: A file like object that is the content to store - content_length: The length of the content - auth_user: The user_id of the uploader - """ - file_info = FileInfo(server_name=None, file_id=media_id) - sha256reader = SHA256TransparentIOReader(content) - # This implements all of IO as it has a passthrough - fname = await self.media_storage.store_file(sha256reader.wrap(), file_info) - sha256 = sha256reader.hexdigest() - should_quarantine = await self.store.get_is_hash_quarantined(sha256) - logger.info("Stored local media in file %r", fname) - - if should_quarantine: - logger.warn( - "Media has been automatically quarantined as it matched existing quarantined media" - ) - - await self.store.update_local_media( - media_id=media_id, - media_type=media_type, - upload_name=upload_name, - media_length=content_length, - user_id=auth_user, - sha256=sha256, - quarantined_by="system" if should_quarantine else None, - ) - - try: - await self._generate_thumbnails(None, media_id, media_id, media_type) - except Exception as e: - logger.info("Failed to generate thumbnails: %s", e) - - @trace - async def create_content( + async def create_or_update_content( self, media_type: str, upload_name: Optional[str], content: IO, content_length: int, auth_user: UserID, + media_id: Optional[str] = None, ) -> MXCUri: - """Store uploaded content for a local user and return the mxc URL + """Create or update the content of the given media ID. Args: media_type: The content type of the file. @@ -349,16 +309,20 @@ async def create_content( content: A file like object that is the content to store content_length: The length of the content auth_user: The user_id of the uploader + media_id: The media ID to update if provided, otherwise creates + new media ID. Returns: The mxc url of the stored content """ - media_id = random_string(24) + is_new_media = media_id is None + if media_id is None: + media_id = random_string(24) file_info = FileInfo(server_name=None, file_id=media_id) - # This implements all of IO as it has a passthrough sha256reader = SHA256TransparentIOReader(content) + # This implements all of IO as it has a passthrough fname = await self.media_storage.store_file(sha256reader.wrap(), file_info) sha256 = sha256reader.hexdigest() should_quarantine = await self.store.get_is_hash_quarantined(sha256) @@ -370,16 +334,56 @@ async def create_content( "Media has been automatically quarantined as it matched existing quarantined media" ) - await self.store.store_local_media( - media_id=media_id, - media_type=media_type, - time_now_ms=self.clock.time_msec(), - upload_name=upload_name, - media_length=content_length, - user_id=auth_user, - sha256=sha256, - quarantined_by="system" if should_quarantine else None, - ) + # Check that the user has not exceeded any of the media upload limits. + + # This is the total size of media uploaded by the user in the last + # `time_period_ms` milliseconds, or None if we haven't checked yet. + uploaded_media_size: Optional[int] = None + + # Note: the media upload limits are sorted so larger time periods are + # first. + for limit in self.media_upload_limits: + # We only need to check the amount of media uploaded by the user in + # this latest (smaller) time period if the amount of media uploaded + # in a previous (larger) time period is above the limit. + # + # This optimization means that in the common case where the user + # hasn't uploaded much media, we only need to query the database + # once. + if ( + uploaded_media_size is None + or uploaded_media_size + content_length > limit.max_bytes + ): + uploaded_media_size = await self.store.get_media_uploaded_size_for_user( + user_id=auth_user.to_string(), time_period_ms=limit.time_period_ms + ) + + if uploaded_media_size + content_length > limit.max_bytes: + raise SynapseError( + 400, "Media upload limit exceeded", Codes.RESOURCE_LIMIT_EXCEEDED + ) + + if is_new_media: + await self.store.store_local_media( + media_id=media_id, + media_type=media_type, + time_now_ms=self.clock.time_msec(), + upload_name=upload_name, + media_length=content_length, + user_id=auth_user, + sha256=sha256, + quarantined_by="system" if should_quarantine else None, + ) + else: + await self.store.update_local_media( + media_id=media_id, + media_type=media_type, + upload_name=upload_name, + media_length=content_length, + user_id=auth_user, + sha256=sha256, + quarantined_by="system" if should_quarantine else None, + ) try: await self._generate_thumbnails(None, media_id, media_id, media_type) diff --git a/synapse/rest/media/upload_resource.py b/synapse/rest/media/upload_resource.py index 572f7897fd8..74d82805824 100644 --- a/synapse/rest/media/upload_resource.py +++ b/synapse/rest/media/upload_resource.py @@ -120,7 +120,7 @@ async def on_POST(self, request: SynapseRequest) -> None: try: content: IO = request.content # type: ignore - content_uri = await self.media_repo.create_content( + content_uri = await self.media_repo.create_or_update_content( media_type, upload_name, content, content_length, requester.user ) except SpamMediaException: @@ -170,13 +170,13 @@ async def on_PUT( try: content: IO = request.content # type: ignore - await self.media_repo.update_content( - media_id, + await self.media_repo.create_or_update_content( media_type, upload_name, content, content_length, requester.user, + media_id=media_id, ) except SpamMediaException: # For uploading of media we want to respond with a 400, instead of diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 04866524e3c..f726846e57f 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -1034,3 +1034,39 @@ def get_matching_media_txn( "local_media_repository", sha256, ) + + async def get_media_uploaded_size_for_user( + self, user_id: str, time_period_ms: int + ) -> int: + """Get the total size of media uploaded by a user in the last + time_period_ms milliseconds. + + Args: + user_id: The user ID to check. + time_period_ms: The time period in milliseconds to consider. + + Returns: + The total size of media uploaded by the user in bytes. + """ + + sql = """ + SELECT COALESCE(SUM(media_length), 0) + FROM local_media_repository + WHERE user_id = ? AND created_ts > ? + """ + + def _get_media_uploaded_size_for_user_txn( + txn: LoggingTransaction, + ) -> int: + # Calculate the timestamp for the start of the time period + start_ts = self._clock.time_msec() - time_period_ms + txn.execute(sql, (user_id, start_ts)) + row = txn.fetchone() + if row is None: + return 0 + return row[0] + + return await self.db_pool.runInteraction( + "get_media_uploaded_size_for_user", + _get_media_uploaded_size_for_user_txn, + ) diff --git a/tests/federation/test_federation_media.py b/tests/federation/test_federation_media.py index 9c92003ce54..cd4905239f0 100644 --- a/tests/federation/test_federation_media.py +++ b/tests/federation/test_federation_media.py @@ -67,7 +67,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def test_file_download(self) -> None: content = io.BytesIO(b"file_to_stream") content_uri = self.get_success( - self.media_repo.create_content( + self.media_repo.create_or_update_content( "text/plain", "test_upload", content, @@ -110,7 +110,7 @@ def test_file_download(self) -> None: content = io.BytesIO(SMALL_PNG) content_uri = self.get_success( - self.media_repo.create_content( + self.media_repo.create_or_update_content( "image/png", "test_png_upload", content, @@ -152,7 +152,7 @@ def test_federation_etag(self) -> None: content = io.BytesIO(b"file_to_stream") content_uri = self.get_success( - self.media_repo.create_content( + self.media_repo.create_or_update_content( "text/plain", "test_upload", content, @@ -215,7 +215,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def test_thumbnail_download_scaled(self) -> None: content = io.BytesIO(small_png.data) content_uri = self.get_success( - self.media_repo.create_content( + self.media_repo.create_or_update_content( "image/png", "test_png_thumbnail", content, @@ -255,7 +255,7 @@ def test_thumbnail_download_scaled(self) -> None: def test_thumbnail_download_cropped(self) -> None: content = io.BytesIO(small_png.data) content_uri = self.get_success( - self.media_repo.create_content( + self.media_repo.create_or_update_content( "image/png", "test_png_thumbnail", content, diff --git a/tests/media/test_media_retention.py b/tests/media/test_media_retention.py index d8f4f57c8c9..89cf61430a8 100644 --- a/tests/media/test_media_retention.py +++ b/tests/media/test_media_retention.py @@ -78,7 +78,7 @@ def _create_media_and_set_attributes( # If the meda random_content = bytes(random_string(24), "utf-8") mxc_uri: MXCUri = self.get_success( - media_repository.create_content( + media_repository.create_or_update_content( media_type="text/plain", upload_name=None, content=io.BytesIO(random_content), diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py index 6ee761e44b9..7aa1f2406cc 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py @@ -1952,7 +1952,7 @@ async def _send_request(*args: Any, **kwargs: Any) -> IResponse: def test_file_download(self) -> None: content = io.BytesIO(b"file_to_stream") content_uri = self.get_success( - self.repo.create_content( + self.repo.create_or_update_content( "text/plain", "test_upload", content, @@ -2846,3 +2846,124 @@ def _check_caching(self, path: str) -> None: custom_headers=[("If-None-Match", etag)], ) self.assertEqual(channel3.code, 404) + + +class MediaUploadLimits(unittest.HomeserverTestCase): + """ + This test case simulates a homeserver with media upload limits configured. + """ + + servlets = [ + media.register_servlets, + login.register_servlets, + admin.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + + self.storage_path = self.mktemp() + self.media_store_path = self.mktemp() + os.mkdir(self.storage_path) + os.mkdir(self.media_store_path) + config["media_store_path"] = self.media_store_path + + provider_config = { + "module": "synapse.media.storage_provider.FileStorageProviderBackend", + "store_local": True, + "store_synchronous": False, + "store_remote": True, + "config": {"directory": self.storage_path}, + } + + config["media_storage_providers"] = [provider_config] + + # These are the limits that we are testing + config["media_upload_limits"] = [ + {"time_period": "1d", "max_size": "1K"}, + {"time_period": "1w", "max_size": "3K"}, + ] + + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.repo = hs.get_media_repository() + self.client = hs.get_federation_http_client() + self.store = hs.get_datastores().main + self.user = self.register_user("user", "pass") + self.tok = self.login("user", "pass") + + def create_resource_dict(self) -> Dict[str, Resource]: + resources = super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + + def upload_media(self, size: int) -> FakeChannel: + """Helper to upload media of a given size.""" + return self.make_request( + "POST", + "/_matrix/media/v3/upload", + content=b"0" * size, + access_token=self.tok, + shorthand=False, + content_type=b"text/plain", + custom_headers=[("Content-Length", str(size))], + ) + + def test_upload_under_limit(self) -> None: + """Test that uploading media under the limit works.""" + channel = self.upload_media(67) + self.assertEqual(channel.code, 200) + + def test_over_day_limit(self) -> None: + """Test that uploading media over the daily limit fails.""" + channel = self.upload_media(500) + self.assertEqual(channel.code, 200) + + channel = self.upload_media(800) + self.assertEqual(channel.code, 400) + + def test_under_daily_limit(self) -> None: + """Test that uploading media under the daily limit fails.""" + channel = self.upload_media(500) + self.assertEqual(channel.code, 200) + + self.reactor.advance(60 * 60 * 24) # Advance by one day + + # This will succeed as the daily limit has reset + channel = self.upload_media(800) + self.assertEqual(channel.code, 200) + + self.reactor.advance(60 * 60 * 24) # Advance by one day + + # ... and again + channel = self.upload_media(800) + self.assertEqual(channel.code, 200) + + def test_over_weekly_limit(self) -> None: + """Test that uploading media over the weekly limit fails.""" + channel = self.upload_media(900) + self.assertEqual(channel.code, 200) + + self.reactor.advance(60 * 60 * 24) # Advance by one day + + channel = self.upload_media(900) + self.assertEqual(channel.code, 200) + + self.reactor.advance(2 * 60 * 60 * 24) # Advance by one day + + channel = self.upload_media(900) + self.assertEqual(channel.code, 200) + + self.reactor.advance(2 * 60 * 60 * 24) # Advance by one day + + # This will fail as the weekly limit has been exceeded + channel = self.upload_media(900) + self.assertEqual(channel.code, 400) + + # Reset the weekly limit by advancing a week + self.reactor.advance(7 * 60 * 60 * 24) # Advance by 7 days + + # This will succeed as the weekly limit has reset + channel = self.upload_media(900) + self.assertEqual(channel.code, 200)