Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions .github/workflows/lint_and_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ jobs:
- name: mypy
run: mypy --install-types --non-interactive ./ --cache-dir=.mypy_cache/

unit_test:
unit_test_non_streaming:
runs-on: ubuntu-latest
timeout-minutes: 20
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.8", "3.9"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -63,6 +63,31 @@ jobs:
pip install -e .
- name: pytest_unit
run: pytest -s -v tests/test_models.py

unit_test_streaming:
runs-on: ubuntu-latest
timeout-minutes: 20
strategy:
fail-fast: false
matrix:
python-version: ["3.10", "3.11"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: "pip"
- name: Install dependencies
run: |
sudo apt-get update
sudo apt-get install -y libsndfile1 ffmpeg
python -m pip install --upgrade pip
pip install -r requirements-dev.txt
pip install torchcodec
pip install -e .
- name: pytest_unit
run: pytest -s -v tests/test_models.py

unit_test_old_torch:
runs-on: ubuntu-latest
Expand Down
11 changes: 10 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,16 @@ All notable changes to AudioSeal are documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.1.5 - 0.1.7] - 2025-04-29

## [0.2.0] - 2025-12-09

- Add new models with streaming support (`audioseal_wm_streaming`, `audioseal_detector_streaming`)
- Refactor code to make the functions torchscriptable
- Deprecate the internal resampling during watermark generation and detection
- Fix bugs in load user-defined secret messages with mis-matching batch size
- Update example notebooks

## [0.1.5 - 0.1.8] - 2025-04-29

- Fix bugs in loading model in new PyTorch (2.6+)
- Add support for loading from other HF spaces
Expand Down
111 changes: 82 additions & 29 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
# :loud_sound: AudioSeal: Proactive Localized Watermarking
# :loud_sound: AudioSeal: Efficient Localized Audio Watermarking

<a href="https://www.python.org/"><img alt="Python" src="https://img.shields.io/badge/-Python 3.8+-blue?style=for-the-badge&logo=python&logoColor=white"></a>
<a href="https://black.readthedocs.io/en/stable/"><img alt="Code style: black" src="https://img.shields.io/badge/code%20style-black-black.svg?style=for-the-badge&labelColor=gray"></a>

This repo contains the Inference code for **AudioSeal**, a method for speech localized watermarking, with state-of-the-art robustness and detector speed (training code coming soon).

To learn more, check out our [paper](https://arxiv.org/abs/2401.17264).
This repo contains the official implementation for **AudioSeal**, a method for efficient audio watermarking, with state-of-the-art robustness and detector speed.

# :rocket: Quick Links:

Expand All @@ -20,6 +18,7 @@ To learn more, check out our [paper](https://arxiv.org/abs/2401.17264).

# :sparkles: Key Updates:

- 2024-12-12: AudioSeal 0.2 is out, with streaming support and other improvement
- 2024-06-17: Training code is now available. Check the [instruction](./docs/TRAINING.md)!!!
- 2024-05-31: Our paper gets accepted at ICML'24 :)
- 2024-04-02: We have updated our license to full MIT license (including the license for the model weights) ! Now you can use AudioSeal in commercial application too!
Expand All @@ -28,23 +27,23 @@ To learn more, check out our [paper](https://arxiv.org/abs/2401.17264).

# :book: Abstract

**AudioSeal** introduces a breakthrough in **proactive, localized watermarking** for speech. It jointly trains two components: a **generator** that embeds an imperceptible watermark into audio and a **detector** that identifies watermark fragments in long or edited audio files.
**AudioSeal** introduces a novel audio watermarking using **ocalized watermarking** and a novel perceptual loss. It jointly trains two components: a **generator** that embeds an imperceptible watermark into audio and a **detector** that identifies watermark fragments in long or edited audio files.

- **Key Features:**
- **Localized watermarking** at the sample level (1/16,000 of a second).
- **Localized watermarking** at the sample level (1/16,000 of a second). AudioSeal works well with other sampling rates as well (24 khZ, 44.5 kHz, 48 kHz)
- Minimal impact on audio quality.
- **Robust** against various audio edits like compression, re-encoding, and noise addition.
- **Fast, single-pass detection** designed to surpass existing models significantly in speed — achieving detection up to **two orders of magnitude faster**, making it ideal for large-scale and real-time applications.


# :gear: Installation
# :mate: Installation

### Requirements:
- Python >= 3.8
- Python >= 3.8 (>= 3.10 for streaming support)
- Pytorch >= 1.13.0
- [Omegaconf](https://omegaconf.readthedocs.io/)
- [Julius](https://pypi.org/project/julius/)
- [Numpy](https://pypi.org/project/numpy/)
- [einops](https://github.com/arogozhnikov/einops) (for streaming support)

### Install from PyPI:
```
Expand Down Expand Up @@ -80,16 +79,18 @@ from audioseal import AudioSeal

# model name corresponds to the YAML card file name found in audioseal/cards
model = AudioSeal.load_generator("audioseal_wm_16bits")
model.eval()

# Other way is to load directly from the checkpoint
# model = Watermarker.from_pretrained(checkpoint_path, device = wav.device)

# a torch tensor of shape (batch, channels, samples) and a sample rate
# It is important to process the audio to the same sample rate as the model
# expects. In our case, we support 16khz audio
wav, sr = ..., 16000
# expects. The default AudioSeal should work well with 16kHz and 24kHz, and
# in the case of 48 khZ, it should work well for most speech audios
wav = [load audio wav into a tensor of BatchxChannelxTime]

watermark = model.get_watermark(wav, sr)
watermark = model.get_watermark(wav)

# Optional: you can add a 16-bit message to embed in the watermark
# msg = torch.randint(0, 2, (wav.shape(0), model.msg_processor.nbits), device=wav.device)
Expand All @@ -100,14 +101,14 @@ watermarked_audio = wav + watermark
detector = AudioSeal.load_detector("audioseal_detector_16bits")

# To detect the messages in the high-level.
result, message = detector.detect_watermark(watermarked_audio, sr)
result, message = detector.detect_watermark(watermarked_audio)

print(result) # result is a float number indicating the probability of the audio being watermarked,
print(message) # message is a binary vector of 16 bits


# To detect the messages in the low-level.
result, message = detector(watermarked_audio, sr)
result, message = detector(watermarked_audio)

# result is a tensor of size batch x 2 x frames, indicating the probability (positive and negative) of watermarking for each frame
# A watermarked audio should have result[:, 1, :] > 0.5
Expand All @@ -118,16 +119,60 @@ print(result[:, 1 , :])
print(message)
```

# :rocket: Train your own watermarking model
# :abacus: Streaming support

Starting AudioSeal 0.2, you can run the watermarking over the stream of audio signals. The API is `model.streaming(batch_size),
which will enable the convolutional cache during the watermark generation. Ensure to put this within context, so the cache is
safely cleaned after the session:

```python

model = AudioSeal.load_generator("audioseal_wm_streaming")
model.eval()

audio = [audio chunks]
streaming_watermarked_audio = []

with model.streaming(batch_size=1):

# Watermark each incoming chunk of the streaming audio
for chunk in audio:
watermarked_chunk = model(chunk, sample_rate=sr, message=secret_mesage, alpha=1)
streaming_watermarked_audio.append(watermarked_chunk)

streaming_watermarked_audio = torch.cat(streaming_watermarked_audio, dim=1)


# You can detect a chunk of watermarked output, or the whole audio:

detector = AudioSeal.load_detector("audioseal_detector_streaming")
detector.eval()

wm_chunk = 100
partial_result, _ = detector.detect_watermark(streaming_watermarked_audio[:, :, :wm_chunk])


full_result, _ = detector.detect_watermark(streaming_watermarked_audio)

```
See [example notebook](examples/Getting_started.ipynb) for full details.


# :brain: Train your own watermarking model

See [here](./docs/TRAINING.md) for details on how to train your own Watermarking model.


# See Also

Interested in training your own watermarking model? Check out our [training documentation](./docs/TRAINING.md) to get started.
The team also develops other open-source watermarking solutions:
- [WMAR](https://github.com/facebookresearch/wmar): Autoregressive watermarking models for images
- [Video Seal](https://github.com/facebookresearch/videoseal): Open and efficient video watermarking
- [WAM](https://github.com/facebookresearch/watermark-anything): Watermark Any Images with Localization

# :wave: Want to contribute?

We welcome pull requests with improvements or suggestions.
If you wish to report an issue or propose an enhancement but are unsure how to implement it, feel free to create a GitHub issue.

# :bug: Troubleshooting
# 🎮 Troubleshooting

- If you encounter the error `ValueError: not enough values to unpack (expected 3, got 2)`, this is because we expect a batch of audio tensors as inputs. Add one
dummy batch dimension to your input (e.g. `wav.unsqueeze(0)`, see [example notebook for getting started](examples/Getting_started.ipynb)).
Expand All @@ -139,19 +184,27 @@ and re-run again.
- If you use torchaudio to handle your audios and encounter the error `Couldn't find appropriate backend to handle uri ...`, this is due to newer version of
torchaudio does not handle the default backend well. Either downgrade your torchaudio to `2.1.0` or earlier, or install `soundfile` as your audio backend.

# :page_with_curl: License
# :heart: Acknowledgements

We borrow the code with some adaptations from the following repos:
- [AudioCraft](https://github.com/facebookresearch/audiocraft/) in `libs/audiocraft/`.
- [Moshi](https://github.com/kyutai-labs/moshi/) in `libs/moshi/`.


# :handshake: Contributions

We welcome Pull Requests with improvements or suggestions.
If you want to flag an issue or propose an improvement, but don't know how to realize it, create a GitHub Issue.


# 🧾 License

- The code in this repository is licensed under the MIT license as detailed in the [LICENSE file](LICENSE). This license permits reuse, modification, and distribution of the software, as long as the original license is included.
- The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE).

# :star2: Maintainers:
- [Tuan Tran](https://github.com/antoine-tran)
- [Hady Elsahar](https://github.com/hadyelsahar)
- [Pierre Fernandez](https://github.com/pierrefdz)
- [Robin San Roman](https://github.com/robinsrm)

# :scroll: Citation
# ✍️ Citation

If you find this repository useful, please consider giving it a star :star: and citing our work:
If you find this repository useful, please consider giving a star :star: and please cite as:

```
@article{sanroman2024proactive,
Expand Down
228 changes: 192 additions & 36 deletions examples/Getting_started.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ classifiers=[
dependencies = [
"numpy",
"omegaconf",
"julius",
"torch>=1.13.0",
"einops; python_version >= '3.10'",
]

[project.urls]
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
numpy
omegaconf
julius
torch>=1.13.0
einops; python_version >= "3.10"
2 changes: 1 addition & 1 deletion src/audioseal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

"""

__version__ = "0.1.8"
__version__ = "0.2.0"


from audioseal import builder
Expand Down
58 changes: 49 additions & 9 deletions src/audioseal/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,37 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import sys
from dataclasses import asdict, dataclass, field, is_dataclass
from typing import Any, Dict, List, Optional

from omegaconf import DictConfig, OmegaConf
from torch import device, dtype
from typing_extensions import TypeAlias

from audioseal.libs import audiocraft
from audioseal.models import AudioSealDetector, AudioSealWM, MsgProcessor
from audioseal.models import (
AudioSealDetector,
AudioSealWM,
MsgProcessor,
NormalizationProcessor,
)

# We use different SEANet implementations based on Python version.
# For 3.10 and above: Moshi's SEANetEncoder and SEANetDecoder.
# For 3.9 and below: Audiocraft's SEANetEncoder and SEANetDecoder..
if sys.version_info >= (3, 10):
from audioseal.libs.moshi.modules.seanet import (
SEANetDecoder,
SEANetEncoder,
SEANetEncoderKeepDimension,
)
else:
from audioseal.libs.audiocraft.modules.seanet import (
SEANetDecoder,
SEANetEncoder,
SEANetEncoderKeepDimension,
)


Device: TypeAlias = device

Expand Down Expand Up @@ -63,13 +85,15 @@ class AudioSealWMConfig:
nbits: int
seanet: SEANetConfig
decoder: DecoderConfig
normalizer: bool = False


@dataclass
class AudioSealDetectorConfig:
nbits: int
seanet: SEANetConfig
detector: DetectorConfig = field(default_factory=lambda: DetectorConfig())
normalizer: bool = False


def as_dict(obj: Any) -> Dict[str, Any]:
Expand All @@ -93,17 +117,25 @@ def create_generator(

# Currently the encoder hparams are the same as
# SEANet, but this can be changed in the future.
encoder = audiocraft.modules.SEANetEncoder(**as_dict(config.seanet))
encoder = encoder.to(device=device, dtype=dtype)
seanet_config = config.seanet

encoder_config = as_dict(seanet_config)
decoder_config = {**as_dict(config.seanet), **as_dict(config.decoder)}
decoder = audiocraft.modules.SEANetDecoder(**as_dict(decoder_config))

encoder = SEANetEncoder(**encoder_config)
encoder = encoder.to(device=device, dtype=dtype)

decoder = SEANetDecoder(**decoder_config)
decoder = decoder.to(device=device, dtype=dtype)

msgprocessor = MsgProcessor(nbits=config.nbits, hidden_size=config.seanet.dimension)
msgprocessor = MsgProcessor(
nbits=config.nbits, hidden_size=seanet_config.dimension)
msgprocessor = msgprocessor.to(device=device, dtype=dtype)

return AudioSealWM(encoder=encoder, decoder=decoder, msg_processor=msgprocessor)
normalizer = NormalizationProcessor() if getattr(
config, "normalizer", False) else None

return AudioSealWM(encoder=encoder, decoder=decoder, msg_processor=msgprocessor, normalizer=normalizer)


def create_detector(
Expand All @@ -112,7 +144,15 @@ def create_detector(
device: Optional[Device] = None,
dtype: Optional[DataType] = None,
) -> AudioSealDetector:
detector_config = {**as_dict(config.seanet), **as_dict(config.detector)}
detector = AudioSealDetector(nbits=config.nbits, **detector_config)

_detector = {"output_dim": 32}
detector_config = {**as_dict(config.seanet), **_detector}

encoder = SEANetEncoderKeepDimension(**detector_config)
normalizer = NormalizationProcessor() if getattr(
config, "normalizer", False) else None

detector = AudioSealDetector(
encoder=encoder, normalizer=normalizer, nbits=config.nbits)
detector = detector.to(device=device, dtype=dtype)
return detector
Loading