diff --git a/.github/workflows/tests-studio.yml b/.github/workflows/tests-studio.yml index 923731fef..76358ecca 100644 --- a/.github/workflows/tests-studio.yml +++ b/.github/workflows/tests-studio.yml @@ -75,6 +75,9 @@ jobs: path: './backend/datachain' fetch-depth: 0 + - name: Set up FFmpeg + uses: AnimMouse/setup-ffmpeg@v1 + - name: Set up Python ${{ matrix.pyv }} uses: actions/setup-python@v5 with: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3b96a12fd..6a807e29f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -78,6 +78,9 @@ jobs: fetch-depth: 0 ref: ${{ github.event.pull_request.head.sha || github.ref }} + - name: Set up FFmpeg + uses: AnimMouse/setup-ffmpeg@v1 + - name: Set up Python ${{ matrix.pyv }} uses: actions/setup-python@v5 with: diff --git a/pyproject.toml b/pyproject.toml index f0d74680b..aad9e61d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,8 +80,16 @@ hf = [ "numba>=0.60.0", "datasets[audio,vision]>=2.21.0" ] +video = [ + # Use 'av<14' because of incompatibility with imageio + # See https://github.com/PyAV-Org/PyAV/discussions/1700 + "av<14", + "ffmpeg-python", + "imageio[ffmpeg]", + "opencv-python" +] tests = [ - "datachain[torch,remote,vector,hf]", + "datachain[torch,remote,vector,hf,video]", "pytest>=8,<9", "pytest-sugar>=0.9.6", "pytest-cov>=4.1.0", diff --git a/src/datachain/__init__.py b/src/datachain/__init__.py index e8bbc00bf..659b2ce5d 100644 --- a/src/datachain/__init__.py +++ b/src/datachain/__init__.py @@ -4,9 +4,14 @@ ArrowRow, File, FileError, + Image, ImageFile, TarVFile, TextFile, + Video, + VideoFile, + VideoFragment, + VideoFrame, ) from datachain.lib.model_store import ModelStore from datachain.lib.udf import Aggregator, Generator, Mapper @@ -27,6 +32,7 @@ "File", "FileError", "Generator", + "Image", "ImageFile", "Mapper", "ModelStore", @@ -34,6 +40,10 @@ "Sys", "TarVFile", "TextFile", + "Video", + "VideoFile", + "VideoFragment", + "VideoFrame", 
"is_chain_type", "metrics", "param", diff --git a/src/datachain/lib/file.py b/src/datachain/lib/file.py index 599fa667e..1f4431371 100644 --- a/src/datachain/lib/file.py +++ b/src/datachain/lib/file.py @@ -17,7 +17,7 @@ from urllib.request import url2pathname from fsspec.callbacks import DEFAULT_CALLBACK, Callback -from PIL import Image +from PIL import Image as PilImage from pydantic import Field, field_validator from datachain.client.fileslice import FileSlice @@ -27,6 +27,7 @@ from datachain.utils import TIME_ZERO if TYPE_CHECKING: + from numpy import ndarray from typing_extensions import Self from datachain.catalog import Catalog @@ -40,7 +41,7 @@ # how to create file path when exporting ExportPlacement = Literal["filename", "etag", "fullpath", "checksum"] -FileType = Literal["binary", "text", "image"] +FileType = Literal["binary", "text", "image", "video"] class VFileError(DataChainError): @@ -193,7 +194,7 @@ def __init__(self, **kwargs): @classmethod def upload( cls, data: bytes, path: str, catalog: Optional["Catalog"] = None - ) -> "File": + ) -> "Self": if catalog is None: from datachain.catalog.loader import get_catalog @@ -203,6 +204,8 @@ def upload( client = catalog.get_client(parent) file = client.upload(data, name) + if not isinstance(file, cls): + file = cls(**file.model_dump()) file._set_stream(catalog) return file @@ -486,13 +489,281 @@ class ImageFile(File): def read(self): """Returns `PIL.Image.Image` object.""" fobj = super().read() - return Image.open(BytesIO(fobj)) + return PilImage.open(BytesIO(fobj)) def save(self, destination: str): """Writes it's content to destination""" self.read().save(destination) +class Image(DataModel): + """ + A data model representing metadata for an image file. + + Attributes: + width (int): The width of the image in pixels. Defaults to -1 if unknown. + height (int): The height of the image in pixels. Defaults to -1 if unknown. + format (str): The format of the image file (e.g., 'jpg', 'png'). 
+ Defaults to an empty string. + """ + + width: int = Field(default=-1) + height: int = Field(default=-1) + format: str = Field(default="") + + +class VideoFile(File): + """ + A data model for handling video files. + + This model inherits from the `File` model and provides additional functionality + for reading video files, extracting video frames, and splitting videos into + fragments. + """ + + def get_info(self) -> "Video": + """ + Retrieves metadata and information about the video file. + + Returns: + Video: A Model containing video metadata such as duration, + resolution, frame rate, and codec details. + """ + from .video import video_info + + return video_info(self) + + def get_frame(self, frame: int) -> "VideoFrame": + """ + Returns a specific video frame by its frame number. + + Args: + frame (int): The frame number to read. + + Returns: + VideoFrame: Video frame model. + """ + if frame < 0: + raise ValueError("frame must be a non-negative integer") + + return VideoFrame(video=self, frame=frame) + + def get_frames( + self, + start: int = 0, + end: Optional[int] = None, + step: int = 1, + ) -> "Iterator[VideoFrame]": + """ + Returns video frames from the specified range in the video. + + Args: + start (int): The starting frame number (default: 0). + end (int, optional): The ending frame number (exclusive). If None, + frames are read until the end of the video + (default: None). + step (int): The interval between frames to read (default: 1). + + Returns: + Iterator[VideoFrame]: An iterator yielding video frames. + + Note: + If end is not specified, number of frames will be taken from the video file, + this means video file needs to be downloaded. 
+ """ + from .video import validate_frame_range + + start, end, step = validate_frame_range(self, start, end, step) + + for frame in range(start, end, step): + yield self.get_frame(frame) + + def get_fragment(self, start: float, end: float) -> "VideoFragment": + """ + Returns a video fragment from the specified time range. + + Args: + start (float): The start time of the fragment in seconds. + end (float): The end time of the fragment in seconds. + + Returns: + VideoFragment: A Model representing the video fragment. + """ + if start < 0 or end < 0 or start >= end: + raise ValueError(f"Invalid time range: ({start:.3f}, {end:.3f})") + + return VideoFragment(video=self, start=start, end=end) + + def get_fragments( + self, + duration: float, + start: float = 0, + end: Optional[float] = None, + ) -> "Iterator[VideoFragment]": + """ + Splits the video into multiple fragments of a specified duration. + + Args: + duration (float): The duration of each video fragment in seconds. + start (float): The starting time in seconds (default: 0). + end (float, optional): The ending time in seconds. If None, the entire + remaining video is processed (default: None). + + Returns: + Iterator[VideoFragment]: An iterator yielding video fragments. + + Note: + If end is not specified, number of frames will be taken from the video file, + this means video file needs to be downloaded. + """ + if duration <= 0: + raise ValueError("duration must be a positive float") + if start < 0: + raise ValueError("start must be a non-negative float") + + if end is None: + end = self.get_info().duration + + if end < 0: + raise ValueError("end must be a non-negative float") + if start >= end: + raise ValueError("start must be less than end") + + while start < end: + yield self.get_fragment(start, min(start + duration, end)) + start += duration + + +class VideoFrame(DataModel): + """ + A data model for representing a video frame. 
+
+    This model references a `VideoFile` via its `video` attribute and adds a
+    `frame` attribute referencing a specific frame within that video. It allows
+    access to individual frames and provides functionality for reading and saving
+    video frames as image files.
+
+    Attributes:
+        video (VideoFile): The video file containing the video frame.
+        frame (int): The frame number referencing a specific frame in the video file.
+    """
+
+    video: VideoFile
+    frame: int
+
+    def get_np(self) -> "ndarray":
+        """
+        Returns a video frame from the video file as a NumPy array.
+
+        Returns:
+            ndarray: A NumPy array representing the video frame,
+                     in the shape (height, width, channels).
+        """
+        from .video import video_frame_np
+
+        return video_frame_np(self.video, self.frame)
+
+    def read_bytes(self, format: str = "jpg") -> bytes:
+        """
+        Returns a video frame from the video file as image bytes.
+
+        Args:
+            format (str): The desired image format (e.g., 'jpg', 'png').
+                          Defaults to 'jpg'.
+
+        Returns:
+            bytes: The encoded video frame as image bytes.
+        """
+        from .video import video_frame_bytes
+
+        return video_frame_bytes(self.video, self.frame, format)
+
+    def save(self, output: str, format: str = "jpg") -> "ImageFile":
+        """
+        Saves the current video frame as an image file.
+
+        If `output` is a remote path, the image file will be uploaded to remote storage.
+
+        Args:
+            output (str): The destination path, which can be a local file path
+                          or a remote URL.
+            format (str): The image format (e.g., 'jpg', 'png'). Defaults to 'jpg'.
+
+        Returns:
+            ImageFile: A Model representing the saved image file.
+        """
+        from .video import save_video_frame
+
+        return save_video_frame(self.video, self.frame, output, format)
+
+
+class VideoFragment(DataModel):
+    """
+    A data model for representing a video fragment.
+
+    This model references a `VideoFile` via its `video` attribute and adds `start`
+ It allows access to individual fragments and provides functionality for reading + and saving video fragments as separate video files. + + Attributes: + video (VideoFile): The video file containing the video fragment. + start (float): The starting time of the video fragment in seconds. + end (float): The ending time of the video fragment in seconds. + """ + + video: VideoFile + start: float + end: float + + def save(self, output: str, format: Optional[str] = None) -> "VideoFile": + """ + Saves the video fragment as a new video file. + + If `output` is a remote path, the video file will be uploaded to remote storage. + + Args: + output (str): The destination path, which can be a local file path + or a remote URL. + format (str, optional): The output video format (e.g., 'mp4', 'avi'). + If None, the format is inferred from the + file extension. + + Returns: + VideoFile: A Model representing the saved video file. + """ + from .video import save_video_fragment + + return save_video_fragment(self.video, self.start, self.end, output, format) + + +class Video(DataModel): + """ + A data model representing metadata for a video file. + + Attributes: + width (int): The width of the video in pixels. Defaults to -1 if unknown. + height (int): The height of the video in pixels. Defaults to -1 if unknown. + fps (float): The frame rate of the video (frames per second). + Defaults to -1.0 if unknown. + duration (float): The total duration of the video in seconds. + Defaults to -1.0 if unknown. + frames (int): The total number of frames in the video. + Defaults to -1 if unknown. + format (str): The format of the video file (e.g., 'mp4', 'avi'). + Defaults to an empty string. + codec (str): The codec used for encoding the video. Defaults to an empty string. 
+ """ + + width: int = Field(default=-1) + height: int = Field(default=-1) + fps: float = Field(default=-1.0) + duration: float = Field(default=-1.0) + frames: int = Field(default=-1) + format: str = Field(default="") + codec: str = Field(default="") + + class ArrowRow(DataModel): """`DataModel` for reading row from Arrow-supported file.""" @@ -528,5 +799,7 @@ def get_file_type(type_: FileType = "binary") -> type[File]: file = TextFile elif type_ == "image": file = ImageFile # type: ignore[assignment] + elif type_ == "video": + file = VideoFile return file diff --git a/src/datachain/lib/hf.py b/src/datachain/lib/hf.py index 66f4ee4fb..2e31c7f84 100644 --- a/src/datachain/lib/hf.py +++ b/src/datachain/lib/hf.py @@ -20,7 +20,7 @@ except ImportError as exc: raise ImportError( - "Missing dependencies for huggingface datasets:\n" + "Missing dependencies for huggingface datasets.\n" "To install run:\n\n" " pip install 'datachain[hf]'\n" ) from exc diff --git a/src/datachain/lib/vfile.py b/src/datachain/lib/vfile.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/datachain/lib/video.py b/src/datachain/lib/video.py new file mode 100644 index 000000000..d007fd30e --- /dev/null +++ b/src/datachain/lib/video.py @@ -0,0 +1,223 @@ +import posixpath +import shutil +import tempfile +from typing import Optional + +from numpy import ndarray + +from datachain.lib.file import FileError, ImageFile, Video, VideoFile + +try: + import ffmpeg + import imageio.v3 as iio +except ImportError as exc: + raise ImportError( + "Missing dependencies for processing video.\n" + "To install run:\n\n" + " pip install 'datachain[video]'\n" + ) from exc + + +def video_info(file: VideoFile) -> Video: + """ + Returns video file information. + + Args: + file (VideoFile): Video file object. + + Returns: + Video: Video file information. 
+ """ + if not (file_path := file.get_local_path()): + file.ensure_cached() + file_path = file.get_local_path() + if not file_path: + raise FileError(file, "unable to download video file") + + try: + probe = ffmpeg.probe(file_path) + except Exception as exc: + raise FileError(file, "unable to extract metadata from video file") from exc + + all_streams = probe.get("streams") + video_format = probe.get("format") + if not all_streams or not video_format: + raise FileError(file, "unable to extract metadata from video file") + + video_streams = [s for s in all_streams if s["codec_type"] == "video"] + if len(video_streams) == 0: + raise FileError(file, "unable to extract metadata from video file") + + video_stream = video_streams[0] + + r_frame_rate = video_stream.get("r_frame_rate", "0") + if "/" in r_frame_rate: + num, denom = r_frame_rate.split("/") + fps = float(num) / float(denom) + else: + fps = float(r_frame_rate) + + width = int(video_stream.get("width", 0)) + height = int(video_stream.get("height", 0)) + duration = float(video_format.get("duration", 0)) + if "nb_frames" in video_stream: + frames = int(video_stream.get("nb_frames", 0)) + else: + start_time = float(video_format.get("start_time", 0)) + frames = int((duration - start_time) * fps) + format_name = video_format.get("format_name", "") + codec_name = video_stream.get("codec_name", "") + + return Video( + width=width, + height=height, + fps=fps, + duration=duration, + frames=frames, + format=format_name, + codec=codec_name, + ) + + +def video_frame_np(video: VideoFile, frame: int) -> ndarray: + """ + Reads video frame from a file and returns as numpy array. + + Args: + video (VideoFile): Video file object. + frame (int): Frame index. + + Returns: + ndarray: Video frame. 
+ """ + if frame < 0: + raise ValueError("frame must be a non-negative integer") + + with video.open() as f: + return iio.imread(f, index=frame, plugin="pyav") # type: ignore[arg-type] + + +def validate_frame_range( + video: VideoFile, + start: int = 0, + end: Optional[int] = None, + step: int = 1, +) -> tuple[int, int, int]: + """ + Validates frame range for a video file. + + Args: + video (VideoFile): Video file object. + start (int): Start frame index (default: 0). + end (int, optional): End frame index (default: None). + step (int): Step between frames (default: 1). + + Returns: + tuple[int, int, int]: Start frame index, end frame index, and step. + """ + if start < 0: + raise ValueError("start_frame must be a non-negative integer.") + if step < 1: + raise ValueError("step must be a positive integer.") + + if end is None: + end = video_info(video).frames + + if end < 0: + raise ValueError("end_frame must be a non-negative integer.") + if start > end: + raise ValueError("start_frame must be less than or equal to end_frame.") + + return start, end, step + + +def video_frame_bytes(video: VideoFile, frame: int, format: str = "jpg") -> bytes: + """ + Reads video frame from a file and returns as image bytes. + + Args: + video (VideoFile): Video file object. + frame (int): Frame index. + format (str): Image format (default: 'jpg'). + + Returns: + bytes: Video frame image as bytes. + """ + img = video_frame_np(video, frame) + return iio.imwrite("", img, extension=f".{format}") + + +def save_video_frame( + video: VideoFile, + frame: int, + output: str, + format: str = "jpg", +) -> ImageFile: + """ + Saves video frame as a new image file. If output is a remote path, + the image file will be uploaded to the remote storage. + + Args: + video (VideoFile): Video file object. + frame (int): Frame index. + output (str): Output path, can be a local path or a remote path. + format (str): Image format (default: 'jpg'). + + Returns: + ImageFile: Image file model. 
+ """ + img = video_frame_bytes(video, frame, format=format) + output_file = posixpath.join( + output, f"{video.get_file_stem()}_{frame:04d}.{format}" + ) + return ImageFile.upload(img, output_file) + + +def save_video_fragment( + video: VideoFile, + start: float, + end: float, + output: str, + format: Optional[str] = None, +) -> VideoFile: + """ + Saves video interval as a new video file. If output is a remote path, + the video file will be uploaded to the remote storage. + + Args: + video (VideoFile): Video file object. + start (float): Start time in seconds. + end (float): End time in seconds. + output (str): Output path, can be a local path or a remote path. + format (str, optional): Output format (default: None). If not provided, + the format will be inferred from the video fragment + file extension. + + Returns: + VideoFile: Video fragment model. + """ + if start < 0 or end < 0 or start >= end: + raise ValueError(f"Invalid time range: ({start:.3f}, {end:.3f})") + + if format is None: + format = video.get_file_ext() + + start_ms = int(start * 1000) + end_ms = int(end * 1000) + output_file = posixpath.join( + output, f"{video.get_file_stem()}_{start_ms:06d}_{end_ms:06d}.{format}" + ) + + temp_dir = tempfile.mkdtemp() + try: + output_file_tmp = posixpath.join(temp_dir, posixpath.basename(output_file)) + ffmpeg.input( + video.get_local_path(), + ss=start, + to=end, + ).output(output_file_tmp).run(quiet=True) + + with open(output_file_tmp, "rb") as f: + return VideoFile.upload(f.read(), output_file) + finally: + shutil.rmtree(temp_dir) diff --git a/tests/unit/lib/data/Big_Buck_Bunny_360_10s_1MB.mp4 b/tests/unit/lib/data/Big_Buck_Bunny_360_10s_1MB.mp4 new file mode 100644 index 000000000..9b6d89da0 Binary files /dev/null and b/tests/unit/lib/data/Big_Buck_Bunny_360_10s_1MB.mp4 differ diff --git a/tests/unit/lib/test_video.py b/tests/unit/lib/test_video.py new file mode 100644 index 000000000..851c14abe --- /dev/null +++ b/tests/unit/lib/test_video.py @@ -0,0 +1,229 
@@ +import io +import os + +import pytest +from numpy import ndarray +from PIL import Image + +from datachain import VideoFragment, VideoFrame +from datachain.lib.file import FileError, ImageFile, VideoFile +from datachain.lib.video import save_video_fragment, video_frame_np + + +@pytest.fixture(autouse=True) +def video_file(catalog) -> VideoFile: + data_path = os.path.join(os.path.dirname(__file__), "data") + file_name = "Big_Buck_Bunny_360_10s_1MB.mp4" + + with open(os.path.join(data_path, file_name), "rb") as f: + file = VideoFile.upload(f.read(), file_name) + + file.ensure_cached() + return file + + +def test_get_info(video_file): + info = video_file.get_info() + assert info.model_dump() == { + "width": 640, + "height": 360, + "fps": 30.0, + "duration": 10.0, + "frames": 300, + "format": "mov,mp4,m4a,3gp,3g2,mj2", + "codec": "h264", + } + + +def test_get_info_error(): + # upload current Python file as video file to get an error while getting video meta + with open(__file__, "rb") as f: + file = VideoFile.upload(f.read(), "test.mp4") + + file.ensure_cached() + with pytest.raises(FileError): + file.get_info() + + +def test_get_frame(video_file): + frame = video_file.get_frame(37) + assert isinstance(frame, VideoFrame) + assert frame.frame == 37 + + +def test_get_frame_error(video_file): + with pytest.raises(ValueError): + video_file.get_frame(-1) + + +def test_get_frame_np(video_file): + frame = video_file.get_frame(0).get_np() + assert isinstance(frame, ndarray) + assert frame.shape == (360, 640, 3) + + +def test_get_frame_np_error(video_file): + with pytest.raises(ValueError): + video_frame_np(video_file, -1) + + +@pytest.mark.parametrize( + "format,img_format,header", + [ + ("jpg", "JPEG", [b"\xff\xd8\xff\xe0"]), + ("png", "PNG", [b"\x89PNG\r\n\x1a\n"]), + ("gif", "GIF", [b"GIF87a", b"GIF89a"]), + ], +) +def test_get_frame_bytes(video_file, format, img_format, header): + frame = video_file.get_frame(0).read_bytes(format) + assert isinstance(frame, bytes) + 
assert any(frame.startswith(h) for h in header) + + img = Image.open(io.BytesIO(frame)) + assert img.format == img_format + assert img.size == (640, 360) + + +@pytest.mark.parametrize("use_format", [True, False]) +def test_save_frame(tmp_path, video_file, use_format): + frame = video_file.get_frame(3) + if use_format: + frame_file = frame.save(str(tmp_path), format="jpg") + else: + frame_file = frame.save(str(tmp_path)) + assert isinstance(frame_file, ImageFile) + + frame_file.ensure_cached() + img = Image.open(frame_file.get_local_path()) + assert img.format == "JPEG" + assert img.size == (640, 360) + + +def test_get_frames(video_file): + frames = list(video_file.get_frames(10, 200, 5)) + assert len(frames) == 38 + assert all(isinstance(frame, VideoFrame) for frame in frames) + + +def test_get_all_frames(video_file): + frames = list(video_file.get_frames()) + assert len(frames) == 300 + assert all(isinstance(frame, VideoFrame) for frame in frames) + + +@pytest.mark.parametrize( + "start,end,step", + [ + (-1, None, 1), + (0, -1, 1), + (1, 0, 1), + (0, 1, -1), + ], +) +def test_get_frames_error(video_file, start, end, step): + with pytest.raises(ValueError): + list(video_file.get_frames(start, end, step)) + + +def test_save_frames(tmp_path, video_file): + frames = list(video_file.get_frames(10, 200, 5)) + frame_files = [frame.save(str(tmp_path), format="jpg") for frame in frames] + assert len(frame_files) == 38 + + for frame_file in frame_files: + frame_file.ensure_cached() + img = Image.open(frame_file.get_local_path()) + assert img.format == "JPEG" + assert img.size == (640, 360) + + +def test_get_fragment(video_file): + fragment = video_file.get_fragment(2.5, 5) + assert isinstance(fragment, VideoFragment) + assert fragment.start == 2.5 + assert fragment.end == 5 + + +def test_get_fragments(video_file): + fragments = list(video_file.get_fragments(duration=1.5)) + for i, fragment in enumerate(fragments): + assert isinstance(fragment, VideoFragment) + assert 
fragment.start == i * 1.5 + duration = 1.5 if i < 6 else 1.0 + assert fragment.end == fragment.start + duration + + +@pytest.mark.parametrize( + "duration,start,end", + [ + (-1, 0, 10), + (1, -1, 10), + (1, 0, -1), + (1, 2, 1), + ], +) +def test_get_fragments_error(video_file, duration, start, end): + with pytest.raises(ValueError): + list(video_file.get_fragments(duration=duration, start=start, end=end)) + + +@pytest.mark.parametrize( + "start,end", + [ + (-1, -1), + (-1, 2.5), + (5, -1), + (5, 2.5), + (5, 5), + ], +) +def test_save_fragment_error(video_file, start, end): + with pytest.raises(ValueError): + video_file.get_fragment(start, end) + + +def test_save_fragment(tmp_path, video_file): + fragment = video_file.get_fragment(2.5, 5).save(str(tmp_path)) + + fragment.ensure_cached() + assert fragment.get_info().model_dump() == { + "width": 640, + "height": 360, + "fps": 30.0, + "duration": 2.5, + "frames": 75, + "format": "mov,mp4,m4a,3gp,3g2,mj2", + "codec": "h264", + } + + +@pytest.mark.parametrize( + "start,end", + [ + (-1, 2), + (1, -1), + (2, 1), + ], +) +def test_save_video_fragment_error(video_file, start, end): + with pytest.raises(ValueError): + save_video_fragment(video_file, start, end, ".") + + +def test_save_fragments(tmp_path, video_file): + fragments = list(video_file.get_fragments(duration=1)) + fragment_files = [fragment.save(str(tmp_path)) for fragment in fragments] + assert len(fragment_files) == 10 + + for fragment in fragment_files: + fragment.ensure_cached() + assert fragment.get_info().model_dump() == { + "width": 640, + "height": 360, + "fps": 30.0, + "duration": 1, + "frames": 30, + "format": "mov,mp4,m4a,3gp,3g2,mj2", + "codec": "h264", + }