Image class, variants and internal data augmentation pipeline

The fastai Image classes

The fastai library is built such that the pictures loaded are wrapped in an Image. This Image contains the array of pixels associated with the picture, but also has a lot of built-in functions that help the fastai library process transformations applied to the corresponding image. There are also sub-classes for special types of image-like objects:

  • ImageSegment for segmentation masks
  • ImagePoints for points
  • ImageBBox for bounding boxes

See the following sections for documentation of all the details of these classes. But first, let's have a quick look at the main functionality you'll need to know about.

Opening an image and converting to an Image object is easily done by using the open_image function:

img = open_image('imgs/cat_example.jpg')
img

To look at the picture that this Image contains, you can also use its show method. It will show a resized version and has more options to customize the display.

img.show()

This show method can take a few arguments (see the documentation of Image.show for details) but the two we will use the most in this documentation are:

  • ax which is the matplotlib.pyplot axes on which we want to show the image
  • title which is an optional title we can give to the image.

_,axs = plt.subplots(1,4,figsize=(12,4))
for i,ax in enumerate(axs): img.show(ax=ax, title=f'Copy {i+1}')

If you're interested in the tensor of pixels, it's stored in the data attribute of an Image.

img.data.shape
torch.Size([3, 500, 394])

The Image classes

Image is the class that wraps every picture in the fastai library. It is subclassed to create ImageSegment and ImagePoints (itself subclassed by ImageBBox) when dealing with segmentation, point and object detection tasks.

class Image

Image(px:Tensor) :: ItemBase

Tests found for Image:

  • pytest -sv tests/test_vision_transform.py::test_mask_data_aug

Some other tests where Image is used:

  • pytest -sv tests/test_vision_image.py::test_image_resize_same_size_shortcut

Support applying transforms to image data in px.

Most of the functions of the Image class deal with the internal pipeline of transforms, so they are only shown at the end of this page. The easiest way to create one is through the function open_image, as we saw before.

open_image

open_image(fn:PathOrStr, div:bool=True, convert_mode:str='RGB', after_open:Callable=None) → Image

Return Image object created from image in file fn.

If div=True, pixel values are divided by 255 to become floats between 0 and 1. convert_mode is passed to PIL.Image.convert.

With the following example, you can get a feel for how open_image works with different values of convert_mode. For the full list of modes, see the Pillow documentation on image modes.

from fastai.vision import *
path_data = untar_data(URLs.PLANET_TINY); path_data.ls()
[PosixPath('/Users/Natsume/.fastai/data/planet_tiny/labels.csv'),
 PosixPath('/Users/Natsume/.fastai/data/planet_tiny/train')]
il = ImageList.from_folder(path_data/'train'); il
ImageList (200 items)
Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128)
Path: /Users/Natsume/.fastai/data/planet_tiny/train
il.convert_mode = 'L'
il.open(il.items[10])
mode = '1'
open_image(il.items[10],convert_mode=mode)

As we saw, in a Jupyter Notebook, the representation of an Image is its underlying picture (shown at its full size). On top of containing the tensor of pixels of the image (and automatically doing the conversion after decoding the image), this class contains various methods for the implementation of transforms. The Image.show method also accepts more arguments:

Image.show

Image.show(ax:Axes=None, figsize:tuple=(3, 3), title:Optional[str]=None, hide_axis:bool=True, cmap:str=None, y:Any=None, **kwargs)

Show image on ax with title, using cmap if single-channel, overlaid with optional y

  • ax: matplotlib.pyplot axes on which to show the image
  • figsize: Size of the figure
  • title: Title to display on top of the graph
  • hide_axis: If True, the axes of the graph are hidden
  • cmap: Color map to use
  • y: Optional target to overlay on the same graph (mask, bounding box, points)

This allows us to completely customize the display of an Image. We'll see examples of the y functionality below with segmentation and bounding box tasks; for now, here is an example using the other features.

img.show(figsize=(2, 1), title='Little kitten')
img.show(figsize=(10,5), title='Big kitten')

With the following example, you will get a feel for how to set cmap for Image.show.

See the matplotlib colormap documentation for the available cmap options. This is how defaults.cmap is defined in fastai:

defaults = SimpleNamespace(cpus=_default_cpus, cmap='viridis', return_fig=False, silent=False)
img.shape
torch.Size([3, 500, 394])

Since cmap only applies to single-channel images, it is necessary to set convert_mode='L' so that the image is reduced to a single channel.

img = open_image('imgs/cat_example.jpg', convert_mode='L'); print(img.shape)
img
torch.Size([1, 500, 394])
img.show(cmap='binary')
img.show(cmap='pink')
defaults.cmap = 'Blues'
img.show()

An Image object also has a few attributes that can be useful:

  • Image.data gives you the underlying tensor of pixels
  • Image.shape gives you the size of that tensor (channels x height x width)
  • Image.size gives you the size of the image (height x width)

img.data, img.shape, img.size
(tensor([[[0.0627, 0.0353, 0.0235,  ..., 0.3647, 0.3843, 0.3843],
          [0.0275, 0.0196, 0.0235,  ..., 0.3686, 0.3843, 0.3882],
          [0.0314, 0.0275, 0.0431,  ..., 0.3804, 0.3804, 0.3804],
          ...,
          [0.3725, 0.4392, 0.4431,  ..., 0.6235, 0.6549, 0.6549],
          [0.3961, 0.4745, 0.4784,  ..., 0.6706, 0.6588, 0.6588],
          [0.4510, 0.5373, 0.5412,  ..., 0.7294, 0.6549, 0.6549]]]),
 torch.Size([1, 500, 394]),
 torch.Size([500, 394]))

For a segmentation task, the target is usually a mask. The fastai library represents it as an ImageSegment object.

class ImageSegment

ImageSegment(px:Tensor) :: Image

Tests found for ImageSegment:

  • pytest -sv tests/test_vision_transform.py::test_mask_data_aug

Support applying transforms to segmentation masks data in px.

To easily open a mask, the function open_mask plays the same role as open_image:

open_mask

open_mask(fn:PathOrStr, div=False, convert_mode='L', after_open:Callable=None) → ImageSegment

Return ImageSegment object created from mask in file fn. If div, divides pixel values by 255.

open_mask('imgs/mask_example.png')

Run length encoded masks

From time to time, you may encounter mask data as a run-length encoded string instead of a picture.

df = pd.read_csv('imgs/mask_rle_sample.csv')
encoded_str = df.iloc[1]['rle_mask']
df[:2]
img rle_mask
0 00087a6bd4dc_01.jpg 879386 40 881253 141 883140 205 885009 17 8850...
1 00087a6bd4dc_02.jpg 873779 4 875695 7 877612 9 879528 12 881267 15...

You can also read a run-length encoded mask with open_mask_rle, passing the image size through the extra argument shape:

mask = open_mask_rle(df.iloc[0]['rle_mask'], shape=(1918, 1280)).resize((1,128,128))
mask

open_mask_rle

open_mask_rle(mask_rle:str, shape:Tuple[int, int]) → ImageSegment

Return ImageSegment object created from run-length encoded string in mask_rle with size in shape.

open_mask_rle simply makes use of the helper function rle_decode:

rle_decode(encoded_str, (1912, 1280)).shape
(1912, 1280)

rle_decode

rle_decode(mask_rle:str, shape:Tuple[int, int]) → ndarray

Tests found for rle_decode:

  • pytest -sv tests/test_vision_image.py::test_rle_decode_empty_str
  • pytest -sv tests/test_vision_image.py::test_rle_decode_with_str

Return an image array from run-length encoded string mask_rle with shape.

You can also convert an ImageSegment to a run-length encoded string.

type(mask)
fastai.vision.image.ImageSegment
rle_encode(mask.data)
'5943 21 6070 25 6197 26 6324 28 6452 29 6579 30 6707 31 6835 31 6962 32 7090 33 7217 34 7345 35 7473 35 7595 2 7600 36 7722 5 7728 37 7766 4 7850 43 7894 5 7978 43 8022 5 8106 49 8238 44 8366 40 8494 41 8621 42 8748 44 8875 46 9003 47 9130 48 9258 49 9386 49 9513 50 9641 51 9769 51 9897 51 10024 52 10152 53 10280 53 10408 53 10536 53 10664 53 10792 53 10920 53 11048 53 11176 53 11304 53 11432 53 11560 53 11688 53 11816 53 11944 53 12072 53 12200 53 12328 53 12456 53 12584 53 12712 53 12840 53 12968 53 13097 51 13225 51 13353 51 13481 51 13610 49 13742 44 13880 30'

rle_encode

rle_encode(img:ndarray) → str

Tests found for rle_encode:

  • pytest -sv tests/test_vision_image.py::test_rle_encode_all_zero_array
  • pytest -sv tests/test_vision_image.py::test_rle_encode_with_array

Return run-length encoding string from img.
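For illustration, here is a pure-Python sketch of a decoder/encoder pair. It assumes the common convention behind strings like the ones above: the mask is flattened in row-major order, and the string alternates 1-indexed start positions with run lengths of pixels equal to 1. rle_decode_sketch and rle_encode_sketch are hypothetical names working on nested lists; fastai's own functions work on numpy arrays and may differ in details.

```python
from typing import List, Tuple

Mask = List[List[int]]  # hypothetical alias: a 2D mask as nested lists of 0/1


def rle_decode_sketch(mask_rle: str, shape: Tuple[int, int]) -> Mask:
    """Decode 'start length start length ...' into a 2D 0/1 mask.

    Assumption: starts are 1-indexed positions in the row-major flattened mask."""
    h, w = shape
    flat = [0] * (h * w)
    tokens = [int(t) for t in mask_rle.split()]
    for start, length in zip(tokens[0::2], tokens[1::2]):
        for k in range(start - 1, start - 1 + length):  # convert to 0-indexed
            flat[k] = 1
    # Cut the flat list back into rows (row-major order).
    return [flat[r * w:(r + 1) * w] for r in range(h)]


def rle_encode_sketch(mask: Mask) -> str:
    """Encode a 2D 0/1 mask as 'start length ...' with 1-indexed starts."""
    flat = [px for row in mask for px in row]
    runs: List[int] = []
    i = 0
    while i < len(flat):
        if flat[i]:
            start = i
            while i < len(flat) and flat[i]:
                i += 1
            runs.extend([start + 1, i - start])  # 1-indexed start, run length
        else:
            i += 1
    return ' '.join(str(x) for x in runs)


example = rle_decode_sketch('2 3 8 2', (3, 4))
print(example)                     # [[0, 1, 1, 1], [0, 0, 0, 1], [1, 0, 0, 0]]
print(rle_encode_sketch(example))  # 2 3 8 2
```

Note that a run can continue from the end of one row into the next, since both functions work on the flattened mask.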

An ImageSegment object has the same properties as an Image. The only difference is that when transformations are applied to an ImageSegment, it ignores the functions that deal with lighting and keeps the mask values intact (for instance 0 and 1 for a binary mask). As explained earlier, it's easy to show the segmentation mask over the associated Image by using the y argument of Image.show.

img = open_image('imgs/car_example.jpg')
mask = open_mask('imgs/mask_example.png')
_,axs = plt.subplots(1,3, figsize=(8,4))
img.show(ax=axs[0], title='no mask')
img.show(ax=axs[1], y=mask, title='masked')
mask.show(ax=axs[2], title='mask only', alpha=1.)

When the targets are a bunch of points, the following class will help.

class ImagePoints

ImagePoints(flow:FlowField, scale:bool=True, y_first:bool=True) :: Image

Tests found for ImagePoints:

  • pytest -sv tests/test_vision_transform.py::test_points_data_aug

Support applying transforms to a flow of points.

Create an ImagePoints object from a flow of coordinates. Coordinates need to be scaled to the range (-1,1), which is done at initialization if scale is left as True. The convention is to have point coordinates in the form [y,x] unless y_first is set to False.

img = open_image('imgs/face_example.jpg')
pnts = torch.load('points.pth')
pnts = ImagePoints(FlowField(img.size, pnts))
img.show(y=pnts)

Note that the raw points are gathered in a FlowField object, which is a class that wraps together a bunch of coordinates with the corresponding image size. In fastai, we expect points to have the y coordinate first by default. The underlying data of pnts is the flow of points scaled from -1 to 1 (again with the y coordinate first):

pnts.data[:10]
tensor([[-0.1875, -0.6000],
        [-0.0500, -0.5875],
        [ 0.0750, -0.5750],
        [ 0.2125, -0.5750],
        [ 0.3375, -0.5375],
        [ 0.4500, -0.4875],
        [ 0.5250, -0.3750],
        [ 0.5750, -0.2375],
        [ 0.5875, -0.1000],
        [ 0.5750,  0.0375]])

For an object detection task, the target is a bounding box surrounding the object in the picture.

class ImageBBox

ImageBBox(flow:FlowField, scale:bool=True, y_first:bool=True, labels:Collection[T_co]=None, classes:dict=None, pad_idx:int=0) :: ImagePoints

Tests found for ImageBBox:

  • pytest -sv tests/test_vision_transform.py::test_bbox_data_aug

Support applying transforms to a flow of bounding boxes.

Create an ImageBBox object from a flow of coordinates. Those coordinates are expected to be in a FlowField with an underlying flow of size 4N if we have N bboxes, describing for each box the top left, top right, bottom left and bottom right corners. Coordinates need to be scaled to the range (-1,1), which is done at initialization if scale is left as True. The convention is to have point coordinates in the form [y,x] unless y_first is set to False. labels is an optional collection of labels, with one label per bounding box. pad_idx is used if the set of transforms somehow leaves the image without any bounding boxes.

To create an ImageBBox, you can use the create helper function, which takes the height and the width of the input image, then a list of bounding boxes. Each bounding box is represented by a list of four numbers: the coordinates of the corners of the box with the following convention: top, left, bottom, right.
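To make the two conventions concrete, this sketch expands [top, left, bottom, right] boxes into the four corner points (in [y, x] order: top left, top right, bottom left, bottom right) that make up a flow of 4N points, and collapses them back with min/max so a box stays valid even after a rotation moves its corners. Both helper names are hypothetical; ImageBBox.create performs a comparable conversion internally, but its actual code may differ.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, float]  # (y, x), following the y-first convention


def bboxes_to_corner_points(bboxes: Sequence[Sequence[float]]) -> List[Point]:
    """Expand N boxes given as [top, left, bottom, right] into 4N corner points."""
    points: List[Point] = []
    for top, left, bottom, right in bboxes:
        # Order per box: top left, top right, bottom left, bottom right.
        points += [(top, left), (top, right), (bottom, left), (bottom, right)]
    return points


def corner_points_to_bboxes(points: Sequence[Point]) -> List[Tuple[float, ...]]:
    """Recover (top, left, bottom, right) from each group of 4 corner points."""
    boxes = []
    for i in range(0, len(points), 4):
        ys = [p[0] for p in points[i:i + 4]]
        xs = [p[1] for p in points[i:i + 4]]
        boxes.append((min(ys), min(xs), max(ys), max(xs)))
    return boxes


pts = bboxes_to_corner_points([[96, 155, 270, 351]])
print(pts)                           # [(96, 155), (96, 351), (270, 155), (270, 351)]
print(corner_points_to_bboxes(pts))  # [(96, 155, 270, 351)]
```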

create

create(h:int, w:int, bboxes:Collection[Collection[int]], labels:Collection[T_co]=None, classes:dict=None, pad_idx:int=0, scale:bool=True) → ImageBBox

Create an ImageBBox object from bboxes.

  • h: height of the input image
  • w: width of the input image
  • bboxes: list of bboxes (each of those being four integers with the top, left, bottom, right convention)
  • labels: labels of the bounding boxes (as indexes)
  • classes: the corresponding classes
  • pad_idx: padding index that will be used to group the ImageBBox in a batch
  • scale: if True, will scale the bounding boxes from -1 to 1

We need to pass the dimensions of the input image so that ImageBBox can internally create the FlowField. Again, the Image.show method will display the bounding box on the same image if it's passed as a y argument.

img = open_image('imgs/car_bbox.jpg')
bbox = ImageBBox.create(*img.size, [[96, 155, 270, 351]], labels=[0], classes=['car'])
img.show(y=bbox)

To help with the conversion of images or to show them, we use these helper functions:

show_image

show_image(img:Image, ax:Axes=None, figsize:tuple=(3, 3), hide_axis:bool=True, cmap:str='binary', alpha:float=None, **kwargs) → Axes

Display Image in notebook.

pil2tensor

pil2tensor(image:ndarray, dtype:dtype) → Tensor

Tests found for pil2tensor:

  • pytest -sv tests/test_vision_data.py::test_vision_pil2tensor
  • pytest -sv tests/test_vision_data.py::test_vision_pil2tensor_16bit
  • pytest -sv tests/test_vision_data.py::test_vision_pil2tensor_numpy

Convert PIL style image array to torch style image tensor.

pil2tensor(PIL.Image.open('imgs/cat_example.jpg').convert("RGB"), np.float32).div_(255).size() 
torch.Size([3, 500, 394])
pil2tensor(PIL.Image.open('imgs/cat_example.jpg').convert("I"), np.float32).div_(255).size()
torch.Size([1, 500, 394])
pil2tensor(PIL.Image.open('imgs/mask_example.png').convert("L"), np.float32).div_(255).size()
torch.Size([1, 128, 128])
pil2tensor(np.random.rand(224,224,3).astype(np.float32), np.float32).size()
torch.Size([3, 224, 224])
pil2tensor(PIL.Image.open('imgs/cat_example.jpg'), np.float32).div_(255).size()
torch.Size([3, 500, 394])
pil2tensor(PIL.Image.open('imgs/mask_example.png'), np.float32).div_(255).size()
torch.Size([1, 128, 128])

image2np

image2np(image:Tensor) → ndarray

Convert from torch style image to numpy/matplotlib style.

scale_flow

scale_flow(flow, to_unit=True)

Scale the coords in flow to -1/1 or the image size depending on to_unit.
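The rescaling is a simple linear map per axis. This sketch assumes that a coordinate c along an axis of length s goes to c / (s/2) - 1 (so 0 becomes -1 and s becomes 1) and comes back with (c + 1) * s/2; scale_coords_sketch is a hypothetical name, and fastai's scale_flow works on a whole FlowField tensor rather than a list of points.

```python
from typing import List, Tuple

Point = Tuple[float, float]  # (y, x)


def scale_coords_sketch(points: List[Point], size: Tuple[int, int],
                        to_unit: bool = True) -> List[Point]:
    """Map pixel coords to [-1, 1] (to_unit=True) or back to pixels (to_unit=False).

    size is (height, width): y is scaled by the height, x by the width."""
    half_h, half_w = size[0] / 2, size[1] / 2
    if to_unit:
        return [(y / half_h - 1, x / half_w - 1) for y, x in points]
    return [((y + 1) * half_h, (x + 1) * half_w) for y, x in points]


size = (500, 394)
unit = scale_coords_sketch([(0, 0), (250, 197), (500, 394)], size)
print(unit)  # [(-1.0, -1.0), (0.0, 0.0), (1.0, 1.0)]
print(scale_coords_sketch(unit, size, to_unit=False))  # [(0.0, 0.0), (250.0, 197.0), (500.0, 394.0)]
```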

bb2hw

bb2hw(a:Collection[int]) → ndarray

Convert bounding box points from (width,height,center) to (height,width,top,left).

tis2hw

tis2hw(size:Union[int, TensorImageSize]) → Tuple[int, int]

Tests found for tis2hw:

  • pytest -sv tests/test_vision_image.py::test_tis2hw_2dims
  • pytest -sv tests/test_vision_image.py::test_tis2hw_3dims
  • pytest -sv tests/test_vision_image.py::test_tis2hw_int
  • pytest -sv tests/test_vision_image.py::test_tis2hw_str_raises_an_error

Convert int or TensorImageSize to (height,width) of an image.

If a size is provided as (int, int), tis2hw will return it as it is. If a size is passed in as a str, tis2hw will raise a RuntimeError.
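This behavior can be mimicked in a few lines. tis2hw_sketch is a hypothetical stand-in; it assumes that an int n means an n x n image and that only the last two entries of a longer size such as (channels, height, width) are kept, matching the cases named in the tests above.

```python
from typing import Sequence, Tuple, Union


def tis2hw_sketch(size: Union[int, Sequence[int]]) -> Tuple[int, int]:
    """Return (height, width) from an int or a (..., height, width) size."""
    if isinstance(size, str):
        # Strings are rejected, as documented for tis2hw.
        raise RuntimeError("Expected size to be an int or a tuple, got a string.")
    if isinstance(size, int):
        return (size, size)            # assumption: an int means a square image
    height, width = tuple(size)[-2:]   # drop a leading channel dimension if present
    return (height, width)


print(tis2hw_sketch(224))            # (224, 224)
print(tis2hw_sketch((500, 394)))     # (500, 394)
print(tis2hw_sketch((3, 500, 394)))  # (500, 394)
```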

Visualization functions

show_all

show_all(imgs:Collection[Image], r:int=1, c:Optional[int]=None, figsize=(12, 6))

Show all imgs using r rows

plot_flat

plot_flat(r, c, figsize)

Shortcut for enumerate(subplots.flatten())

plot_multi

plot_multi(func:Callable[int, int, Axes, NoneType], r:int=1, c:int=1, figsize:Tuple=(12, 6))

Call func for every combination of r,c on a subplot

Internally, it first creates r × c subplots, stored in axes, and then loops over each of them, calling func to create a plot on it.

show_multi

show_multi(func:Callable[int, int, Image], r:int=1, c:int=1, figsize:Tuple=(9, 9))

Call func(i,j).show(ax) for every combination of r,c

Applying transforms

All the transforms available for data augmentation in computer vision are defined in the vision.transform module. When we want to apply them to an Image, we use this method:

apply_tfms

apply_tfms(tfms:Union[Callable, Collection[Callable]], do_resolve:bool=True, xtra:Optional[Dict[Callable, dict]]=None, size:Union[int, TensorImageSize, NoneType]=None, resize_method:ResizeMethod=None, mult:int=None, padding_mode:str='reflection', mode:str='bilinear', remove_out:bool=True) → Tensor

Apply all tfms to the Image, if do_resolve picks value for random args.

  • tfms: Transform or list of Transform
  • do_resolve: if False, the values of random parameters are kept from the last draw
  • xtra: extra arguments to pass to the transforms
  • size: desired target size
  • mult: makes sure the final size is a multiple of mult
  • resize_method: how to get to the final size (crop, pad, squish)
  • padding_mode: how to pad the image ('zeros', 'border', 'reflection')

Before showing examples, let's take a few moments to comment on those arguments a bit more:

  • do_resolve decides if we resolve the random arguments by drawing new numbers or not. The intended use is to apply the tfms to the input x with do_resolve=True then, if data augmentation must also be applied to the target y (if it's a segmentation mask or bounding box), apply the tfms to y with do_resolve=False.
  • The default value of mult is very important to make sure your image can pass through most recent CNNs: they divide the size of the input image by 2 multiple times, so both dimensions of your picture should be multiples of at least 32. Only change the value of this parameter if you know it will be accepted by your model.

Here are a few helper functions to help us load the examples we saw before.

def get_class_ex(): return open_image('imgs/cat_example.jpg')
def get_seg_ex(): return open_image('imgs/car_example.jpg'), open_mask('imgs/mask_example.png')
def get_pnt_ex():
    img = open_image('imgs/face_example.jpg')
    pnts = torch.load('points.pth')
    return img, ImagePoints(FlowField(img.size, pnts))
def get_bb_ex():
    img = open_image('imgs/car_bbox.jpg')
    return img, ImageBBox.create(*img.size, [[96, 155, 270, 351]], labels=[0], classes=['car'])

Now let's grab our usual bunch of transforms and see what they do.

tfms = get_transforms()
_, axs = plt.subplots(2,4,figsize=(12,6))
for ax in axs.flatten():
    img = get_class_ex().apply_tfms(tfms[0], size=224)
    img.show(ax=ax)

Now let's check what it gives for a segmentation task. Note that, as instructed by the documentation of apply_tfms, we first apply the transforms to the input, and then apply them to the target while adding do_resolve=False.

tfms = get_transforms()
_, axs = plt.subplots(2,4,figsize=(12,6))
for ax in axs.flatten():
    img,mask = get_seg_ex()
    img = img.apply_tfms(tfms[0], size=224)
    mask = mask.apply_tfms(tfms[0], do_resolve=False, size=224)
    img.show(ax=ax, y=mask)

Internally, each transform saves the values it randomly picked into a dictionary called resolved, which it can reuse for the target.

tfms[0][4]
RandTransform(tfm=TfmAffine (zoom), kwargs={'scale': (1.0, 1.1), 'row_pct': (0, 1), 'col_pct': (0, 1)}, p=0.75, resolved={'scale': 1.0479406124400892, 'row_pct': 0.4050331333782575, 'col_pct': 0.6185771791644814}, do_run=True, is_random=True, use_on_y=True)

For points, ImagePoints will apply the transforms to the coordinates.

tfms = get_transforms()
_, axs = plt.subplots(2,4,figsize=(12,6))
for ax in axs.flatten():
    img,pnts = get_pnt_ex()
    img = img.apply_tfms(tfms[0], size=224)
    pnts = pnts.apply_tfms(tfms[0], do_resolve=False, size=224)
    img.show(ax=ax, y=pnts)

Now for the bounding box, the ImageBBox will automatically update the coordinates of the two opposite corners in its data attribute.

tfms = get_transforms()
_, axs = plt.subplots(2,4,figsize=(12,6))
for ax in axs.flatten():
    img,bbox = get_bb_ex()
    img = img.apply_tfms(tfms[0], size=224)
    bbox = bbox.apply_tfms(tfms[0], do_resolve=False, size=224)
    img.show(ax=ax, y=bbox)

Fastai internal pipeline

What does a transform do?

Typically, a data augmentation operation will randomly modify an image input. This operation can apply to pixels (when we modify the contrast or brightness for instance) or to coordinates (when we do a rotation, a zoom or a resize). The operations that apply to pixels can easily be coded in numpy/pytorch, directly on an array/tensor but the ones that modify the coordinates are a bit more tricky.

They usually come in three steps. First, we create a grid of coordinates for our picture: this is an array of size h * w * 2 (h for height, w for width in the rest of this post) that contains, in position i,j, two floats representing the position of the pixel (i,j) in the picture. They could simply be the integers i and j, but since most transformations use the center of the picture as origin, they’re usually rescaled to go from -1 to 1, (-1,-1) being the top left corner of the picture, (1,1) the bottom right corner (and (0,0) the center), and this can be seen as a regular grid of size h * w. Here is what our grid would look like for a 5px by 5px image.

Example of grid
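Building this grid amounts to spacing h values and w values evenly between -1 and 1. A minimal pure-Python sketch, using the y-first convention from earlier on this page; axis_coords and make_grid are hypothetical helpers (in fastai, _affine_grid builds the grid as a tensor):

```python
from typing import List, Tuple


def axis_coords(n: int) -> List[float]:
    """n evenly spaced positions from -1 (first pixel) to 1 (last pixel); needs n >= 2."""
    return [-1 + 2 * k / (n - 1) for k in range(n)]


def make_grid(h: int, w: int) -> List[List[Tuple[float, float]]]:
    """h x w grid where entry [i][j] holds the (y, x) position of pixel (i, j) in [-1, 1]."""
    ys, xs = axis_coords(h), axis_coords(w)
    return [[(y, x) for x in xs] for y in ys]


grid = make_grid(5, 5)
print(grid[0][0], grid[2][2], grid[4][4])  # (-1.0, -1.0) (0.0, 0.0) (1.0, 1.0)
print([x for _, x in grid[0]])             # [-1.0, -0.5, 0.0, 0.5, 1.0]
```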

Then, we apply the transformation to modify this grid of coordinates. For instance, if we want to apply an affine transformation (like a rotation) we will transform each of those vectors x of size 2 by A @ x + b at every position in the grid. This will give us the new coordinates, as seen here in the case of our previous grid.

Example of grid rotated
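Applying one affine transform to the grid is a matrix-vector product at every point. Here is a sketch with a 2 x 2 matrix A and an offset b; rotation and apply_affine are illustrative helpers, not fastai functions.

```python
import math
from typing import List, Sequence, Tuple

Vec = Tuple[float, float]
Mat2 = Tuple[Tuple[float, float], Tuple[float, float]]


def rotation(theta: float) -> Mat2:
    """2 x 2 rotation matrix for an angle theta in radians."""
    c, s = math.cos(theta), math.sin(theta)
    return ((c, -s), (s, c))


def apply_affine(points: Sequence[Vec], a: Mat2, b: Vec = (0.0, 0.0)) -> List[Vec]:
    """Return A @ p + b for every point p: the new coordinates of the grid."""
    return [(a[0][0] * p0 + a[0][1] * p1 + b[0],
             a[1][0] * p0 + a[1][1] * p1 + b[1]) for p0, p1 in points]


corners = [(-1.0, -1.0), (1.0, -1.0), (1.0, 1.0), (-1.0, 1.0)]
rotated = apply_affine(corners, rotation(math.radians(30)))
# After a 30 degree rotation some coordinates reach about +/-1.366,
# i.e. they fall outside [-1, 1]: the second problem discussed next.
print(max(abs(v) for p in rotated for v in p) > 1)  # True
```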

There are two problems that arise after the transformation: the first one is that the new coordinates won’t fall exactly on the grid, and the other is that we can get values that fall outside the grid (one of the coordinates is greater than 1 or lower than -1).

To solve the first problem, we use an interpolation. If we forget the rescale for a minute and go back to coordinates being integers, the result of our transformation gives us float coordinates, and we need to decide, for each (i,j), which pixel value in the original picture we need to take. The most basic interpolation, called nearest neighbor, just rounds the floats and takes the nearest integers. If we think in terms of the grid of coordinates (going from -1 to 1), the result of our transformation gives a point that isn’t in the grid, and we replace it by its nearest neighbor in the grid.

To be smarter, we can perform a bilinear interpolation. This takes an average of the values of the pixels corresponding to the four points in the grid surrounding the result of our transformation, with weights depending on how close we are to each of those points. This comes at a computational cost though, so this is where we have to be careful.
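Both strategies can be written down for a single sample point. This sketch works in pixel coordinates (row, col) on a small 2D list rather than on the normalized grid, and assumes the point lies inside the image; nearest and bilinear are illustrative names, not fastai functions.

```python
import math
from typing import List

Image2D = List[List[float]]


def nearest(img: Image2D, r: float, c: float) -> float:
    """Nearest neighbor: round the float coordinates to the closest pixel.

    Note that Python's round() sends exact halves to the even integer."""
    return img[int(round(r))][int(round(c))]


def bilinear(img: Image2D, r: float, c: float) -> float:
    """Bilinear: weighted average of the 4 pixels surrounding (r, c).

    Each pixel's weight grows as (r, c) gets closer to it."""
    h, w = len(img), len(img[0])
    r0, c0 = int(math.floor(r)), int(math.floor(c))
    r1, c1 = min(r0 + 1, h - 1), min(c0 + 1, w - 1)
    dr, dc = r - r0, c - c0
    top = img[r0][c0] * (1 - dc) + img[r0][c1] * dc
    bottom = img[r1][c0] * (1 - dc) + img[r1][c1] * dc
    return top * (1 - dr) + bottom * dr


img = [[0.0, 10.0],
       [20.0, 30.0]]
print(nearest(img, 0.4, 0.6))   # 10.0: (0.4, 0.6) rounds to pixel (0, 1)
print(bilinear(img, 0.5, 0.5))  # 15.0: the average of the four pixels
```

The bilinear version reads four pixels and does several multiplications per output pixel, which is where its extra computational cost comes from.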

As for the values that go out of the picture, we handle them by padding the picture, either:

  • by adding zeros on the side, so the pixels that fall out will be black (zero padding)
  • by replacing them by the value at the border (border padding)
  • by mirroring the content of the picture on the other side (reflection padding).
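The three padding modes boil down to different ways of mapping an out-of-range index back into the picture. Here is a 1D sketch on pixel indices; pad_lookup is a hypothetical helper, and exact reflection conventions (for instance whether the edge pixel is repeated) differ between libraries.

```python
from typing import Sequence


def pad_lookup(row: Sequence[float], i: int, mode: str) -> float:
    """Value at index i of a 1D row, where i may fall outside [0, len(row))."""
    n = len(row)
    if 0 <= i < n:
        return row[i]
    if mode == 'zeros':
        return 0.0                          # pixels that fall out are black
    if mode == 'border':
        return row[min(max(i, 0), n - 1)]   # repeat the value at the edge
    if mode == 'reflection':
        if n == 1:
            return row[0]
        period = 2 * (n - 1)                # mirror without repeating the edge pixel
        j = i % period                      # Python's % is never negative here
        return row[j if j < n else period - j]
    raise ValueError(f'unknown padding mode: {mode}')


row = [1.0, 2.0, 3.0, 4.0, 5.0]
for mode in ('zeros', 'border', 'reflection'):
    print(mode, [pad_lookup(row, i, mode) for i in (-2, -1, 5, 6)])
# zeros [0.0, 0.0, 0.0, 0.0]
# border [1.0, 1.0, 5.0, 5.0]
# reflection [3.0, 2.0, 4.0, 3.0]
```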

Be smart and efficient

Usually, data augmentation libraries have separated the different operations. So for a resize, we’ll go through the three steps above, then if we do a random rotation, we’ll go through those steps again, then for a zoom, etc. The fastai library works differently: it does all the transformations on the coordinates at the same time, so that we only go through those three steps once. This matters most for the last step (the interpolation), which is the most computationally expensive.

The first thing is that we can regroup all affine transforms in just one pass (because an affine transform composed with an affine transform is another affine transform). This is already done in some other libraries, but we pushed it one step further: we integrated the resize, the crop and any non-affine transformation of the coordinates into the same process. Let’s dig in!

  • In step 1, when we create the grid, we use the new size we want for our image, so new_h, new_w (and not h, w). This takes care of the resize operation.
  • In step 2, we do only one affine transformation, by multiplying all the affine matrices of the transforms we want to do beforehand (those are 3 by 3 matrices, so it’s super fast). Then we apply to the coordinates any non-affine transformation we might want (jitter, perspective warping, etc.) before...
  • Step 2.5: we crop (either center or randomly) the coordinates we want to keep. Cropping could have been done at any point, but by doing it just before the interpolation, we don’t compute pixel values that won’t be used at the end, gaining again a bit of efficiency.
  • Finally, step 3: the final interpolation. Afterward, we can apply on the picture all the transforms that operate pixel-wise (brightness or contrast for instance) and we’re done with data augmentation.
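The trick in step 2 relies on homogeneous coordinates: writing each affine map as a 3 x 3 matrix turns composition into a matrix product, so the matrices can be multiplied once and every grid point transformed a single time. A sketch with a rotation followed by a scaling; all helper names are hypothetical and the multiplication order is this sketch's own convention.

```python
import math
from typing import List, Tuple

Mat3 = List[List[float]]


def matmul(a: Mat3, b: Mat3) -> Mat3:
    """3 x 3 matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def apply(m: Mat3, p: Tuple[float, float]) -> Tuple[float, float]:
    """Apply m to point p, written in homogeneous coordinates as (p0, p1, 1)."""
    x, y = p
    return (m[0][0] * x + m[0][1] * y + m[0][2],
            m[1][0] * x + m[1][1] * y + m[1][2])


def rotate(theta: float) -> Mat3:
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


def scaling(factor: float) -> Mat3:
    return [[factor, 0.0, 0.0], [0.0, factor, 0.0], [0.0, 0.0, 1.0]]


rot, sc = rotate(math.radians(30)), scaling(1.2)
combined = matmul(sc, rot)  # "rotate, then scale" as one matrix, computed once

grid = [(-1.0, -1.0), (0.0, 0.5), (1.0, 1.0)]
one_pass = [apply(combined, p) for p in grid]
two_passes = [apply(sc, apply(rot, p)) for p in grid]
print(all(math.isclose(u, v, abs_tol=1e-12)
          for p, q in zip(one_pass, two_passes) for u, v in zip(p, q)))  # True
```

With many transforms and a large grid, this replaces one pass over every point per transform with a handful of cheap 3 x 3 products plus a single pass.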

Note that the transforms operating on pixels are applied in two phases:

  • first the transforms that deal with lighting properties are applied to the logits of the pixels. We group them together so we only need to do the conversion pixels -> logits -> pixels transformation once.
  • then we apply the transforms that modify the pixels.
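The logit trick from the first bullet can be sketched per pixel value: map p in (0, 1) to logit(p) = log(p / (1 - p)), let every lighting function act there, then map back once with the sigmoid. Whatever the functions do to the logits, the result stays strictly between 0 and 1. brightness_shift and contrast_scale are hypothetical lighting functions (not fastai's brightness or contrast transforms), and clamping by EPS is an assumption this sketch adds so that values of exactly 0 or 1 keep a finite logit.

```python
import math
from typing import Callable, List

EPS = 1e-6  # assumption: clamp away from 0 and 1 so the logit stays finite


def logit(p: float) -> float:
    p = min(max(p, EPS), 1 - EPS)
    return math.log(p / (1 - p))


def sigmoid(z: float) -> float:
    return 1 / (1 + math.exp(-z))


def apply_lighting(pixels: List[float],
                   funcs: List[Callable[[float], float]]) -> List[float]:
    """pixels -> logits -> every lighting func -> pixels, converting only once."""
    out = []
    for p in pixels:
        z = logit(p)
        for f in funcs:
            z = f(z)
        out.append(sigmoid(z))
    return out


def brightness_shift(change: float) -> Callable[[float], float]:
    """Hypothetical lighting function: adding to the logit brightens (change > 0)."""
    return lambda z: z + change


def contrast_scale(scale: float) -> Callable[[float], float]:
    """Hypothetical lighting function: scaling the logit pushes values away from 0.5."""
    return lambda z: z * scale


pixels = [0.05, 0.25, 0.5, 0.75, 0.95]
result = apply_lighting(pixels, [brightness_shift(2.0), contrast_scale(1.5)])
print(result)  # every value stays strictly between 0 and 1
```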

This is why every transform has a type (such as TfmAffine, TfmCoord, TfmLighting, TfmCrop or TfmPixel), so that the fastai library can regroup them and apply them all together at the right step. In terms of implementation:

  • _affine_grid is responsible for creating the grid of coordinates
  • _affine_mult is in charge of doing the affine multiplication on that grid
  • _grid_sample is the function that is responsible for the interpolation step

Final result

Adding a new transformation doesn't impact performance much (since the costly steps are done only once). By contrast, in libraries with classic data augmentation implementations, each additional augmentation usually results in a longer training time.

In terms of final result, doing only one interpolation also gives a better result. If we stack several transforms and do an interpolation on each one, we approximate the true value of our coordinates in some way. This tends to blur the image a bit, which often negatively affects performance. By regrouping all the transformations together and only doing this step at the end, the image is often less blurry and the model often performs better.

See how the same rotation then zoom done separately (so there are two interpolations):

Image interpolated twice

is blurrier than regrouping the transforms and doing just one interpolation:

Image interpolated once
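The blurring can be reproduced in one dimension with a toy example. The sketch below uses linear interpolation with border padding (shift is a hypothetical helper, not a fastai function): a one-pixel shift done in a single pass lands exactly on existing samples, whereas the same shift done as two half-pixel passes averages neighboring values twice and smears a sharp edge.

```python
import math
from typing import List


def shift(signal: List[float], offset: float) -> List[float]:
    """Resample signal at positions i + offset with linear interpolation.

    Out-of-range positions are clamped to the border (border padding)."""
    n = len(signal)
    out = []
    for i in range(n):
        x = min(max(i + offset, 0.0), n - 1.0)
        lo = int(math.floor(x))
        hi = min(lo + 1, n - 1)
        t = x - lo
        out.append(signal[lo] * (1 - t) + signal[hi] * t)
    return out


edge = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]   # a sharp step edge
once = shift(edge, 1.0)                   # both shifts combined, interpolated once
twice = shift(shift(edge, 0.5), 0.5)      # same total shift, interpolated twice
print(once)   # [0.0, 0.0, 1.0, 1.0, 1.0, 1.0]: still a sharp edge
print(twice)  # [0.0, 0.25, 0.75, 1.0, 1.0, 1.0]: the edge is smeared
```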

ResizeMethod

Enum = [CROP, PAD, SQUISH, NO]

Resize methods to transform an image to a given size:

  • crop: resize so that the image fits in the desired canvas on its smaller side and crop
  • pad: resize so that the image fits in the desired canvas on its bigger side and pad
  • squish: resize the image by squishing it in the desired canvas
  • no: doesn't resize the image
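For crop and pad, the difference comes down to which side decides the scale factor. This sketch assumes the aspect ratio is preserved before cropping or padding to the target canvas; resize_plan is a hypothetical helper, and fastai's rounding details may differ.

```python
from typing import Tuple


def resize_plan(h: int, w: int, th: int, tw: int, method: str) -> Tuple[int, int]:
    """(height, width) of an h x w image after rescaling for a th x tw canvas."""
    if method == 'crop':
        s = max(th / h, tw / w)  # smaller side fits exactly, the other overflows and is cropped
    elif method == 'pad':
        s = min(th / h, tw / w)  # bigger side fits exactly, the other falls short and is padded
    elif method == 'squish':
        return (th, tw)          # each side scaled independently: aspect ratio not kept
    elif method == 'no':
        return (h, w)            # no resizing at all
    else:
        raise ValueError(f'unknown method: {method}')
    return (round(h * s), round(w * s))


for m in ('crop', 'pad', 'squish', 'no'):
    print(m, resize_plan(500, 394, 224, 224, m))
# crop (284, 224)
# pad (224, 177)
# squish (224, 224)
# no (500, 394)
```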

Transform classes

The basic class that defines transformation in the fastai library is Transform.

class Transform

Transform(func:Callable, order:Optional[int]=None)

Create a Transform for func and assign it a priority order.

class RandTransform

RandTransform(tfm:Transform, kwargs:dict, p:float=1.0, resolved:dict=<factory>, do_run:bool=True, is_random:bool=True, use_on_y:bool=True)

Wrap Transform to add randomized execution.

Each argument of func in kwargs is analyzed and, if it has a type annotation that is a random function, this function will be called to pick a value for it. This value will be stored in the resolved dictionary. Following the same idea, p is the probability for func to be called, and do_run will be set to True if it was the case, False otherwise. Setting is_random to False allows you to send specific values for each parameter. use_on_y is a parameter to further control transformations for targets (e.g. segmentation masks): assuming transformations on labels are turned on using tfm_y=True (in your data block pipeline), use_on_y=False disables the transformation for labels.
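The resolve mechanics can be sketched with the standard library. SketchRandTransform is a hypothetical, much simplified stand-in: it keeps the field names from the signature above (kwargs, p, resolved, do_run, is_random), but represents each random argument as a (low, high) range sampled uniformly, which is an assumption rather than fastai's annotation-based mechanism. It also shows the do_resolve=False pattern used to reuse the draw for a target.

```python
import math
import random
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional


@dataclass
class SketchRandTransform:
    """Hypothetical, simplified stand-in for RandTransform (not fastai's code)."""
    func: Callable[..., Any]
    kwargs: Dict[str, Any]  # if is_random: name -> (low, high) range; else name -> fixed value
    p: float = 1.0
    resolved: Dict[str, Any] = field(default_factory=dict)
    do_run: bool = True
    is_random: bool = True

    def resolve(self, rng: random.Random) -> None:
        """Bind the random variables: pick argument values and decide whether to run."""
        if not self.is_random:
            self.resolved = dict(self.kwargs)  # specific values are passed through as given
            return
        self.resolved = {k: rng.uniform(lo, hi) for k, (lo, hi) in self.kwargs.items()}
        self.do_run = rng.random() < self.p    # True with probability p

    def __call__(self, x: Any, do_resolve: bool = True,
                 rng: Optional[random.Random] = None) -> Any:
        if do_resolve:
            self.resolve(rng or random.Random())
        # With do_resolve=False, the values drawn last time are reused (e.g. for the target).
        return self.func(x, **self.resolved) if self.do_run else x


def add_offset(x: float, amount: float) -> float:
    """Hypothetical deterministic transform wrapped by SketchRandTransform."""
    return x + amount


tfm = SketchRandTransform(add_offset, {'amount': (0.0, 1.0)}, p=0.75)
x_out = tfm(10.0, rng=random.Random(0))  # input x: draws new values
y_out = tfm(20.0, do_resolve=False)      # target y: reuses tfm.resolved and tfm.do_run
print(tfm.resolved, tfm.do_run)
print(math.isclose(x_out - 10.0, y_out - 20.0, abs_tol=1e-9))  # True
```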

resolve

resolve()

Bind any random variables in the transform.

To handle the data augmentation internally as explained earlier, each Transform has a type, so that the fastai library can regroup them together efficiently. There are five types of Transform, which all work as decorators for a deterministic function.

class TfmAffine

TfmAffine(func:Callable, order:Optional[int]=None) :: Transform

Decorator for affine tfm funcs.

func should return the 3 by 3 matrix representing the transform. The default order is 5 for such transforms.

class TfmCoord

TfmCoord(func:Callable, order:Optional[int]=None) :: Transform

Decorator for coord tfm funcs.

func should take two mandatory arguments: c (the flow of coordinates) and img_size (the size of the corresponding image) and return the modified flow of coordinates. The default order is 4 for such transforms.

class TfmLighting

TfmLighting(func:Callable, order:Optional[int]=None) :: Transform

Decorator for lighting tfm funcs.

func takes the logits of the pixel tensor and changes them. The default order is 8 for such transforms.

class TfmPixel

TfmPixel(func:Callable, order:Optional[int]=None) :: Transform

Decorator for pixel tfm funcs.

func takes the pixel tensor and modifies it. The default order is 10 for such transforms.

class TfmCrop

TfmCrop(func:Callable, order:Optional[int]=None) :: TfmPixel

Decorator for crop tfm funcs.

This is a special case of TfmPixel with order set to 99.

Internal functions of the Image classes

All the Image classes have the same internal functions that deal with data augmentation.

affine

affine(func:AffineFunc, *args, **kwargs) → Image

Equivalent to image.affine_mat = image.affine_mat @ func().

clone

clone()

Mimic the behavior of torch.clone for Image objects.

coord

coord(func:Callable[FlowField, ArgStar, KWArgs, Tensor], *args, **kwargs) → Image

Equivalent to image.flow = func(image.flow, image.size).

lighting

lighting(func:LightingFunc, *args:Any, **kwargs:Any)

Equivalent to image = sigmoid(func(logit(image))).

pixel

pixel(func:LightingFunc, *args, **kwargs) → Image

Equivalent to image.px = func(image.px).

refresh

refresh()

Apply any logit, flow, or affine transforms that have been sent to the Image.

resize

resize(size:Union[int, TensorImageSize]) → Image

Tests found for resize:

  • pytest -sv tests/test_vision_image.py::test_image_resize_same_size_shortcut

Resize the image to size, size can be a single int.

save

save(fn:PathOrStr)

Save the image to fn.

class FlowField

FlowField(size:Tuple[int, int], flow:Tensor)

Wrap together some coords flow with a size.