|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +from typing import Any, Callable, Dict, Optional |
| 8 | + |
| 9 | +from torchtune.data import InputOutputToMessages |
| 10 | +from torchtune.datasets._sft import SFTDataset |
| 11 | +from torchtune.modules.transforms import Transform |
| 12 | + |
| 13 | + |
def vqa_dataset(
    model_transform: Transform,
    *,
    source: str,
    image_dir: Optional[str] = None,
    column_map: Optional[Dict[str, str]] = None,
    new_system_prompt: Optional[str] = None,
    filter_fn: Optional[Callable] = None,
    split: str = "train",
    **load_dataset_kwargs: Dict[str, Any],
) -> SFTDataset:
    """
    Configure a custom visual question answer dataset with separate columns for user question, image, and model response.

    This builder function can be used to configure a custom visual question answer dataset directly from the yaml config
    as an alternative to :class:`~torchtune.datasets.SFTDataset`, as it is made to be config friendly.

    The dataset should follow this format:

    .. code-block:: text

        | input           | image           | output           |
        |-----------------|-----------------|------------------|
        | "user prompt"   | images/1.jpg    | "model response" |

    If your column names are different, you can use the ``column_map`` parameter to change
    the expected column names. For example, if your dataset has columns ``"question"``,
    ``"answer"`` and ``"picture"`` you can use:

        column_map = {"input": "question", "output": "answer", "image": "picture"}

    Args:
        model_transform (Transform): callable that applies model-specific pre-processing to the sample.
            This includes tokenization and any modality-specific transforms. It is expected to return at
            minimum ``"tokens"`` and ``"mask"`` keys.
        source (str): path to dataset repository on Hugging Face. For local datasets,
            define source as the data file type (e.g. "json", "csv", "text"), pass
            in the filepath in ``data_files``, and set ``split="train"``. See `Hugging Face's
            <https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path>`_
            ``load_dataset`` for more details.
        image_dir (Optional[str]): path to the directory containing the images that is prepended to all image
            paths in the dataset. For example, if ``image_dir="/home/user/dataset/"`` and the sample image path
            was ``"images/1.jpg"``, the final image path that will be loaded is ``"/home/user/dataset/images/1.jpg"``.
            If None, assume images are available in current working directory or are located
            on a remote url. For text-only, leave as None. Default is None.
        column_map (Optional[Dict[str, str]]): a mapping to change the expected "input",
            "output", and "image" column names to the actual column names in the dataset. Keys should be "input",
            "output", and "image", and values should be the actual column names.
            Default is None, keeping the default "input", "output", and "image" column names.
        new_system_prompt (Optional[str]): if specified, prepend a system message. This can
            serve as instructions to guide the model response. Setting this will OVERRIDE any system
            messages already present in the dataset. Default is None.
        filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
            the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
            details.
        split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
            of a given split, e.g. ``split="train[:10%]"``. Default is "train".
        **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``,
            such as ``data_files``.

    Examples:

    ::

        my_dataset.json
        [
            {
                "question": "What is presented on the image?",
                "answer": "PyTorch logo.",
                "picture": "rgb_pytorch.png"
            },
            {
                ...
            },
            ...,
        ]

    ::

        >>> from torchtune.datasets.multimodal import vqa_dataset
        >>> dataset = vqa_dataset(
        ...     model_transform=model_transform,
        ...     source="json",
        ...     data_files="my_dataset.json",
        ...     column_map={
        ...         "input": "question",
        ...         "output": "answer",
        ...         "image": "picture"
        ...     },
        ...     split="train",
        ... )
        >>> tokens = dataset[0]["tokens"]
        >>> model_transform.decode(tokens)
        "What is presented on the image?PyTorch logo."

    This can also be accomplished via the yaml config:

    .. code-block:: yaml

        dataset:
          _component_: torchtune.datasets.multimodal.vqa_dataset
          source: json
          data_files: my_dataset.json
          column_map:
            input: question
            output: answer
            image: picture
          split: train

    Returns:
        SFTDataset: the configured :class:`~torchtune.datasets.SFTDataset`
    """
    # Convert raw "input"/"image"/"output" columns into the Message format the
    # model transform expects; image_dir is prepended to relative image paths.
    message_transform = InputOutputToMessages(
        column_map=column_map, new_system_prompt=new_system_prompt, image_dir=image_dir
    )

    # SFTDataset handles loading from Hugging Face / local files and applies
    # message_transform then model_transform to each sample.
    ds = SFTDataset(
        source=source,
        message_transform=message_transform,
        model_transform=model_transform,
        filter_fn=filter_fn,
        split=split,
        **load_dataset_kwargs,
    )
    return ds
0 commit comments