Intro to TorchData: A Walkthrough with Conceptual Captions 3M

Learn how to use TorchData and DataPipes to efficiently stream large datasets like Conceptual Captions 3M.

Frank Odom
Towards Data Science


Overview

When working with large datasets, especially in deep learning, it can be impractical to download them locally for training. Instead, streaming the dataset directly during training can be a more efficient approach. In this tutorial, we will introduce the TorchData library and demonstrate how to use it to stream the Conceptual Captions 3M dataset, which consists of 3 million images with their corresponding captions.

Note: Conceptual Captions is freely available under an open-source license. For more information, see the LICENSE file in the official GitHub repo.

We’ll start by providing a brief overview of TorchData and its main components. Then, we’ll walk through the process of setting up our data pipeline for the Conceptual Captions 3M dataset, and finally, we’ll show an example of how to use the pipeline to stream the dataset in real-time.

This tutorial is designed to be accessible to absolute beginners, so we’ll take the time to explain each concept and code snippet in detail. Let’s get started!


Intro to TorchData

TorchData is a library of common data loading methods for easily constructing flexible and performant data pipelines. An excerpt from the TorchData README:

It introduces composable Iterable-style and Map-style building blocks called DataPipes that work well out of the box with PyTorch’s DataLoader. These built-in DataPipes provide functionalities for loading files (from local or cloud storage), parsing, caching, transforming, filtering, and many more utilities.

DataPipes


At the core of TorchData are DataPipes, which can be thought of as composable building blocks for data pipelines. DataPipes are simply renamed and repurposed PyTorch Datasets designed for composed usage. They take in an access function over Python data structures, __iter__ for IterDataPipes and __getitem__ for MapDataPipes, and return a new access function with a slight transformation applied.

By chaining together DataPipes, we can create sophisticated data pipelines with streamed operation as a first-class citizen. This enables us to handle large datasets efficiently and reduce the need for local storage.
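To make the "access function" idea concrete, here is a plain-Python sketch with no TorchData dependency. The classes ListSource, MapStage, and FilterStage are hypothetical, written only to mirror the shape of IterDataPipes: each one stores the stage below it and defines __iter__ by pulling items from that stage and applying a small transformation. This is an illustration of the concept, not TorchData's implementation.

```python
from typing import Callable, Iterable, Iterator


class ListSource:
    """Hypothetical source stage: exposes an existing iterable via __iter__."""

    def __init__(self, items: Iterable[int]) -> None:
        self.items = items

    def __iter__(self) -> Iterator[int]:
        return iter(self.items)


class MapStage:
    """Hypothetical stage: wraps another stage's __iter__ and transforms each item."""

    def __init__(self, source: Iterable[int], fn: Callable[[int], int]) -> None:
        self.source = source
        self.fn = fn

    def __iter__(self) -> Iterator[int]:
        for item in self.source:
            yield self.fn(item)


class FilterStage:
    """Hypothetical stage: wraps another stage's __iter__ and keeps matching items."""

    def __init__(self, source: Iterable[int], keep: Callable[[int], bool]) -> None:
        self.source = source
        self.keep = keep

    def __iter__(self) -> Iterator[int]:
        for item in self.source:
            if self.keep(item):
                yield item


# Building the chain does no work; items are produced only when we iterate.
pipe = FilterStage(MapStage(ListSource(range(6)), lambda x: x * 10), lambda x: x > 20)
result = list(pipe)
print(result)  # [30, 40, 50]
```

Because each stage pulls one item at a time from the stage below it, only the item currently being processed has to be in memory, which is what lets a chain like this stream a dataset far larger than RAM.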

Example

Let’s start with an example to get familiar with the basic concepts. Let’s create a basic DataPipe that takes an iterable of integers and doubles their values:

from torchdata.datapipes.iter import IterDataPipe

class DoublingDataPipe(IterDataPipe):
    def __init__(self, source_data):
        self.source_data = source_data

    def __iter__(self):
        for item in self.source_data:
            yield item * 2

# Initialize the DataPipe with a list of integers.
source_data = [1, 2, 3, 4, 5]
doubling_data_pipe = DoublingDataPipe(source_data)

# Iterate over the DataPipe and print the results.
for doubled_item in doubling_data_pipe:
    print(doubled_item)

This code defines a custom DoublingDataPipe that takes an iterable source of data (in our case, a list of integers) and yields each item from the source data multiplied by 2. When we run this code, we should see the doubled values printed:

2
4
6
8
10

TorchData also provides many built-in pipeline methods, which can make this example much more concise.

The .map(), .filter(), .shuffle(), and .concat() methods, to name a few, let us quickly build powerful and flexible data pipelines without writing a custom DataPipe for every operation. They can be applied directly to an IterDataPipe to perform common data processing tasks: applying transformations, filtering data, randomizing order, and concatenating multiple DataPipes.

Let’s explore a few examples. Each one wraps the same list of integers from our previous example in an IterableWrapper, which turns an ordinary Python iterable into an IterDataPipe.

1. .map(): Applies a function to each element in the DataPipe.

from torchdata.datapipes.iter import IterableWrapper

data_pipe = IterableWrapper([1, 2, 3, 4, 5])

# Double each element in the DataPipe.
doubled_data_pipe = data_pipe.map(lambda x: x * 2)

for item in doubled_data_pipe:
    print(item)

# Output: 2, 4, 6, 8, 10

The entire DoublingDataPipe example from earlier is reproduced here in a single line of code: data_pipe.map(lambda x: x * 2).

2. .filter(): Filters the elements in the DataPipe based on a condition.

from torchdata.datapipes.iter import IterableWrapper

data_pipe = IterableWrapper([1, 2, 3, 4, 5])

# Filter out odd elements in the DataPipe.
even_data_pipe = data_pipe.filter(lambda x: x % 2 == 0)

for item in even_data_pipe:
    print(item)

# Output: 2, 4

3. .shuffle(): Randomizes the order of elements in the DataPipe.

from torchdata.datapipes.iter import IterableWrapper

data_pipe = IterableWrapper([1, 2, 3, 4, 5])

# Shuffle the elements in the DataPipe.
shuffled_data_pipe = data_pipe.shuffle(buffer_size=5)

for item in shuffled_data_pipe:
    print(item)

# Output: Randomly ordered elements, e.g., 3, 1, 5, 2, 4
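The buffer_size argument matters because a stream cannot be shuffled all at once. The stdlib-only sketch below, under the hypothetical name buffered_shuffle, illustrates the general buffered-shuffle idea. It is similar in spirit to a streaming shuffle, but it is not TorchData's source code.

```python
import random
from typing import Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")


def buffered_shuffle(
    items: Iterable[T], buffer_size: int, rng: Optional[random.Random] = None
) -> Iterator[T]:
    """Yield items in randomized order, holding at most buffer_size items in memory."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = rng if rng is not None else random.Random()
    buffer: List[T] = []
    for item in items:
        if len(buffer) < buffer_size:
            # Fill the buffer before emitting anything.
            buffer.append(item)
        else:
            # Emit a random buffered item and store the new item in its slot.
            idx = rng.randrange(buffer_size)
            yield buffer[idx]
            buffer[idx] = item
    # The source is exhausted: emit whatever is left, in random order.
    rng.shuffle(buffer)
    yield from buffer


# A buffer as large as the input behaves like a full shuffle...
print(list(buffered_shuffle([1, 2, 3, 4, 5], buffer_size=5, rng=random.Random(0))))
# ...while a buffer of 1 cannot reorder anything.
print(list(buffered_shuffle([1, 2, 3, 4, 5], buffer_size=1)))  # [1, 2, 3, 4, 5]
```

With buffer_size=5 and only five elements, the whole input fits in the buffer, which is why the .shuffle() example above can produce any ordering. On a real stream the buffer is usually much smaller than the dataset, and in this sketch an item can never come out buffer_size or more positions ahead of where it started. A larger buffer gives a more thorough shuffle at the cost of memory.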

4. .concat(): Concatenates two or more DataPipes.

from torchdata.datapipes.iter import IterableWrapper

data_pipe1 = IterableWrapper([1, 2, 3])
data_pipe2 = IterableWrapper([4, 5, 6])

# Concatenate the two DataPipes: all of data_pipe1, then all of data_pipe2.
concatenated_data_pipe = data_pipe1.concat(data_pipe2)

for item in concatenated_data_pipe:
    print(item)

# Output: 1, 2, 3, 4, 5, 6

Setting Up Conceptual Captions 3M


In this section, we’ll walk through the process of setting up our data pipeline for the Conceptual Captions 3M dataset. This dataset consists of 3 million images and their corresponding captions, making it impractical to download locally for training. Instead, we’ll use TorchData to stream the dataset directly during training.

Dependencies

Before diving into the code, let’s first install the required dependencies for this tutorial. You’ll need the following Python packages:

  • torchdata
  • tqdm (for displaying progress bars)
  • aiohttp (for asynchronous HTTP requests)
  • Pillow (for handling images)

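You can install them using pip. DataPipes were deprecated in later TorchData releases and removed in version 0.10, so pin an earlier version to run the code in this tutorial:

pip install "torchdata<0.10" tqdm aiohttp Pillow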

Async Helper Functions for Image Download

Since we’ll be streaming images from remote URLs, we need a way to efficiently download them in parallel. We’ll use the aiohttp library to make asynchronous HTTP requests.

First, let’s create a helper function, async_get_image, that takes an aiohttp.ClientSession and a URL as input and returns the downloaded image, or None if anything goes wrong:

import io
from typing import Optional

import aiohttp
from PIL import Image


async def async_get_image(
    session: aiohttp.ClientSession, url: str
) -> Optional[Image.Image]:
    try:
        # Using the response as a context manager releases the connection
        # back to the session's pool once the body has been read.
        async with session.get(url) as resp:
            image_bytes = await resp.read()
        return Image.open(io.BytesIO(image_bytes))
    except Exception:
        # If an exception occurs, such as a timeout, an invalid URL, or data
        # that isn't an image, just return None so the caller can skip it.
        return None

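Next, we’ll create another helper function, async_batch_get_images, that takes a sequence of URLs and returns a list of downloaded images, with None in place of any that failed. It shares a single aiohttp.ClientSession across all requests (so connections can be reused) and uses asyncio.gather to run the requests concurrently, which is crucial for performance when fetching a large number of images from remote URLs in real time. The timeout argument is the maximum total time, in seconds, allowed for each request.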

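import asyncio
from typing import List, Optional, Sequence

async def async_batch_get_images(
    urls: Sequence[str], timeout: float = 1.0
) -> List[Optional[Image.Image]]:
    # The session-level timeout applies to each request individually.
    client_timeout = aiohttp.ClientTimeout(total=timeout)
    async with aiohttp.ClientSession(timeout=client_timeout) as session:
        # Start all downloads concurrently and wait for every one to finish.
        return await asyncio.gather(*[async_get_image(session, url) for url in urls])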
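To see why gathering the requests matters, here is an offline sketch that replaces the network with asyncio.sleep. fake_get_image and fake_batch_get_images are hypothetical stand-ins for the two helpers above, and asyncio.wait_for stands in for aiohttp's timeout; only the timing and the None-on-failure behavior carry over to the real code.

```python
import asyncio
import time
from typing import Dict, List, Optional


async def fake_get_image(url: str, delay: float, timeout: float) -> Optional[str]:
    """Hypothetical stand-in for async_get_image: sleeps instead of downloading."""
    try:
        # asyncio.wait_for plays the role of the aiohttp timeout here.
        await asyncio.wait_for(asyncio.sleep(delay), timeout=timeout)
        return f"image from {url}"
    except Exception:
        # Same policy as async_get_image: any failure becomes None.
        return None


async def fake_batch_get_images(
    delays: Dict[str, float], timeout: float
) -> List[Optional[str]]:
    # All requests run concurrently; results keep the input order.
    return await asyncio.gather(
        *[fake_get_image(url, delay, timeout) for url, delay in delays.items()]
    )


delays = {"a": 0.1, "b": 0.2, "slow": 5.0, "c": 0.15}
start = time.perf_counter()
results = asyncio.run(fake_batch_get_images(delays, timeout=0.5))
elapsed = time.perf_counter() - start
print(results)  # ['image from a', 'image from b', None, 'image from c']
print(f"{elapsed:.2f}s")  # close to the 0.5 s timeout, not the 5.45 s sum of delays
```

Because the requests overlap, a batch takes about as long as its slowest request (capped by the timeout) rather than the sum of all of them. The same cap means a short timeout trades a few dropped samples for steadier throughput.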

ParallelSampleLoader DataPipe


Now that we have our helper functions for downloading images, let’s create a custom ParallelSampleLoader DataPipe. It takes an IterDataPipe of (caption, image URL) tuples, downloads the images in batches of buffer_size, and yields (image, caption) pairs, skipping any sample whose image failed to download.

import asyncio
from typing import Generator, List, Tuple

from PIL import Image
from torchdata.datapipes.iter import IterDataPipe

class ParallelSampleLoader(IterDataPipe):
    def __init__(
        self, dp: IterDataPipe[Tuple[str, str]], buffer_size: int = 256
    ) -> None:
        super().__init__()
        self.dp = dp
        self.buffer_size = buffer_size

    def __iter__(self) -> Generator[Tuple[Image.Image, str], None, None]:
        # Group the incoming (caption, url) tuples into lists of buffer_size.
        pipe: IterDataPipe[List[Tuple[str, str]]] = self.dp.batch(self.buffer_size)
        for batch in pipe:
            # The batch is a list of tuples, where the first element is the
            # caption, and the second element is the URL of the image.
            captions = [x[0] for x in batch]
            image_urls = [x[1] for x in batch]
            # Download the whole batch concurrently. Note that asyncio.run()
            # cannot be called from a running event loop (e.g. inside Jupyter).
            images = asyncio.run(async_batch_get_images(image_urls))

            for image, caption in zip(images, captions):
                # Skip samples whose image failed to download.
                if image is not None:
                    yield image, caption
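The batching logic in __iter__ does not depend on TorchData or on real downloads, so it can be checked offline. Below is a hypothetical, framework-free restatement (parallel_sample_loader and fake_fetch are illustrative names, not part of any library): itertools.islice replaces .batch(), and the fetch coroutine is passed in so a fake can stand in for async_get_image.

```python
import asyncio
from itertools import islice
from typing import Awaitable, Callable, Iterable, Iterator, List, Optional, Tuple


def parallel_sample_loader(
    samples: Iterable[Tuple[str, str]],
    fetch: Callable[[str], Awaitable[Optional[bytes]]],
    buffer_size: int = 256,
) -> Iterator[Tuple[bytes, str]]:
    """Hypothetical, framework-free version of ParallelSampleLoader.__iter__."""
    it = iter(samples)
    while True:
        # Pull up to buffer_size (caption, url) pairs from the stream.
        batch = list(islice(it, buffer_size))
        if not batch:
            return
        captions = [caption for caption, _ in batch]
        urls = [url for _, url in batch]

        async def fetch_batch() -> List[Optional[bytes]]:
            # Download the whole batch concurrently.
            return await asyncio.gather(*[fetch(url) for url in urls])

        for data, caption in zip(asyncio.run(fetch_batch()), captions):
            if data is not None:  # skip samples whose download failed
                yield data, caption


# Offline usage with a fake fetcher that "fails" on URLs containing "bad".
async def fake_fetch(url: str) -> Optional[bytes]:
    await asyncio.sleep(0)
    return None if "bad" in url else url.encode()


samples = [("cap1", "u1"), ("cap2", "bad-url"), ("cap3", "u3")]
loaded = list(parallel_sample_loader(samples, fake_fetch, buffer_size=2))
print(loaded)  # [(b'u1', 'cap1'), (b'u3', 'cap3')]
```

One consequence of this design, shared by both versions: each batch is downloaded completely before any of its samples are yielded, so the consumer waits on the slowest image (up to the timeout) in every batch of buffer_size.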

Putting It All Together

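Finally, we’ll create a function, conceptual_captions_3m, that takes a split argument (either "train" or "val") and returns an IterDataPipe of tuples containing the downloaded images and their corresponding captions. It streams the split’s TSV file over HTTP, splits each line into a caption and an image URL, and passes those pairs to ParallelSampleLoader.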

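from typing import Tuple

from PIL import Image
from torchdata.datapipes.iter import HttpReader, IterableWrapper, IterDataPipe, LineReader

# Assumed locations of the official Conceptual Captions TSV files (one
# "caption<TAB>url" row per line); verify them against the official download page.
TSV_URLS = {
    "train": "https://storage.googleapis.com/gcc-data/Train/GCC-training.tsv",
    "val": "https://storage.googleapis.com/gcc-data/Validation/GCC-1.1.0-Validation.tsv",
}

def _datapipe_from_tsv_url(
    tsv_url: str, buffer_size: int = 256
) -> IterDataPipe[Tuple[Image.Image, str]]:
    # HttpReader streams the file and yields (url, stream) tuples.
    pipe = HttpReader(IterableWrapper([tsv_url]))
    # LineReader yields one line at a time, as raw bytes by default.
    pipe = LineReader(pipe, return_path=False)
    # Decode each line to a string, then split it into [caption, url].
    pipe = pipe.map(lambda line: line.decode("utf-8").split("\t"))

    return ParallelSampleLoader(pipe, buffer_size=buffer_size)

def conceptual_captions_3m(
    split: str = "train", buffer_size: int = 256
) -> IterDataPipe[Tuple[Image.Image, str]]:
    return _datapipe_from_tsv_url(tsv_url=TSV_URLS[split], buffer_size=buffer_size)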
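Because the real TSV file is large, it can help to test the parsing step on a few in-memory lines first. The function below, parse_tsv_lines, is a hypothetical stdlib-only equivalent of the LineReader and .map() steps in _datapipe_from_tsv_url; io.BytesIO stands in for the HTTP stream, and the sample rows are made up for illustration.

```python
import io
from typing import Iterable, Iterator, Tuple


def parse_tsv_lines(lines: Iterable[bytes]) -> Iterator[Tuple[str, str]]:
    """Hypothetical offline version of the LineReader + .map() steps."""
    for line in lines:
        # Each row is assumed to be "caption<TAB>url"; strip the line ending.
        text = line.decode("utf-8").rstrip("\r\n")
        # Split once from the right so the URL stays intact even if a
        # caption were ever to contain a tab character.
        caption, url = text.rsplit("\t", 1)
        yield caption, url


sample_tsv = io.BytesIO(
    b"a dog running on the beach\thttp://example.com/1.jpg\n"
    b"sunset over the hills\thttp://example.com/2.jpg\n"
)
rows = list(parse_tsv_lines(sample_tsv))
print(rows[0])  # ('a dog running on the beach', 'http://example.com/1.jpg')
```

Splitting from the right with maxsplit=1 is a defensive design choice, not something the dataset is known to require; the pipeline’s plain split("\t") behaves identically whenever each row has exactly one tab.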

Streaming Conceptual Captions 3M

With our data pipeline set up, we can now use it to stream the Conceptual Captions 3M dataset in real time. In this example, we’ll use the conceptual_captions_3m function to create an IterDataPipe for the training split, then iterate over it and print the first 10 captions along with the sizes of their images:

# Create the IterDataPipe for the training split.
data_pipe = conceptual_captions_3m(split="train")

for i, (image, caption) in enumerate(data_pipe):
    if i >= 10:
        break
    print(f"Caption {i + 1}: {caption}")
    print(f"Image size: {image.size}")

Here’s another simple example to benchmark how quickly images can be loaded through this pipeline. We use the tqdm library to create a progress bar, which displays the number of samples iterated per second.

from tqdm import tqdm

# Create the IterDataPipe for the training split.
data_pipe = conceptual_captions_3m(split="train")

for image, caption in tqdm(data_pipe):
    # Don't do anything here. We just want to test the loading speed.
    pass
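If you want a single number to log instead of a progress bar, a small stdlib-only helper can consume a bounded number of samples and report the rate. measure_throughput and slow_source below are hypothetical names; slow_source fakes a pipeline that produces one item every 10 ms so the sketch runs offline.

```python
import time
from typing import Iterable, Iterator, Tuple


def measure_throughput(samples: Iterable[object], max_samples: int) -> Tuple[int, float]:
    """Consume up to max_samples items; return (count, items per second)."""
    count = 0
    start = time.perf_counter()
    for _ in samples:
        count += 1
        if count >= max_samples:
            break
    elapsed = time.perf_counter() - start
    return count, (count / elapsed if elapsed > 0 else float("inf"))


def slow_source(n: int, delay: float) -> Iterator[int]:
    """Hypothetical stand-in for a streaming pipeline: one item every `delay` s."""
    for i in range(n):
        time.sleep(delay)
        yield i


count, rate = measure_throughput(slow_source(1_000, 0.01), max_samples=20)
print(f"{count} samples at {rate:.0f} samples/sec")  # at most about 100/sec
```

With network access, you could pass conceptual_captions_3m(split="train") as samples. Since ParallelSampleLoader delivers images in batches of buffer_size (256 by default), set max_samples to several multiples of that so the bursty, batch-at-a-time delivery averages out.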

Download speeds depend heavily on your internet connection. Virtual machines from most cloud providers have very fast network connections, which makes them well suited to streaming data with DataPipes. I ran the benchmark above on my Google Cloud VM, which reached roughly 120 images per second. For small-scale ML training on a single GPU-enabled machine, that should be more than fast enough. (Few ML models train faster than 120 images/sec unless you’re using more expensive GPU hardware.)
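To put that rate in perspective, here is the back-of-the-envelope time for one full pass over the dataset, assuming the round figure of 3 million images all arrive at the roughly 120 images per second measured above:

```latex
t_{\text{epoch}} \approx \frac{3 \times 10^{6}\ \text{images}}{120\ \text{images/s}} = 25{,}000\ \text{s} \approx 6.9\ \text{h}
```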

Conclusion

TorchData offers a powerful and flexible way to handle large datasets through composable DataPipes and built-in pipeline methods. With these tools, you can stream and process your data on the fly, without downloading it to local disk beforehand. This approach not only saves time and storage, but also integrates the dataset more seamlessly into your project: all of the dataset logic is codified in your Python code and does not require the detailed setup instructions that many projects put in their README. By encapsulating the pipeline within your code, TorchData allows for better reproducibility and portability, making it an invaluable tool for modern machine learning projects dealing with large datasets.

With this tutorial, you should now have a better understanding of how to use the TorchData library to create a data pipeline for streaming large datasets like Conceptual Captions 3M. This approach can be applied to other large datasets, and it can be easily adapted for various data processing and augmentation tasks.
