Source code for ztlearn.utils.im2col_utils

# -*- coding: utf-8 -*-

import numpy as np

def _same_pad_along(input_size, stride, kernel_size):
    """Return the total 'same' padding (both sides combined) for one axis."""
    remainder = input_size % stride
    # Subtract the stride when it divides the input evenly, otherwise the
    # leftover remainder; clamp at zero so padding is never negative.
    return max(kernel_size - (remainder if remainder else stride), 0)


def get_pad(padding, input_height, input_width, stride_height, stride_width, kernel_height, kernel_width):
    """
    Compute per-side padding amounts for a 2-D sliding-window operation.

    Args:
        padding: padding mode, either ``'valid'`` (no padding) or ``'same'``.
        input_height, input_width: spatial size of the input.
        stride_height, stride_width: window stride along each axis.
        kernel_height, kernel_width: window size along each axis.

    Returns:
        ``((pad_top, pad_bottom), (pad_left, pad_right))``. When the total
        padding along an axis is odd, the extra unit goes to the bottom/right.

    Raises:
        ValueError: if ``padding`` is neither ``'valid'`` nor ``'same'``.
    """
    if padding == 'valid':
        return (0, 0), (0, 0)

    if padding != 'same':
        # Previously an unknown mode silently returned None, which only
        # surfaced later as an unpacking error in the caller.
        raise ValueError(
            "unsupported padding mode {!r}; expected 'valid' or 'same'".format(padding)
        )

    pad_along_height = _same_pad_along(input_height, stride_height, kernel_height)
    pad_along_width = _same_pad_along(input_width, stride_width, kernel_width)

    pad_top = pad_along_height // 2
    pad_bottom = pad_along_height - pad_top
    pad_left = pad_along_width // 2
    pad_right = pad_along_width - pad_left

    return (pad_top, pad_bottom), (pad_left, pad_right)
# Original Code: CS231n Stanford http://cs231n.github.io/assignments2017/assignment2/
def get_im2col_indices(x_shape, field_height = 3, field_width = 3, padding = ((0, 0), (0, 0)), stride = 1):
    """
    Build the fancy-indexing arrays used to gather sliding windows.

    Args:
        x_shape: 4-tuple unpacked as ``(N, C, H, W)``; only ``C``, ``H`` and
            ``W`` are used here.
        field_height, field_width: window size along each spatial axis.
        padding: ``(pad_height, pad_width)``; each entry is summed to get the
            total padding along that axis.
        stride: step between consecutive windows (same for both axes).

    Returns:
        ``(k, i, j)`` integer index arrays. ``k`` has shape
        ``(C * field_height * field_width, 1)``; ``i`` and ``j`` have shape
        ``(C * field_height * field_width, out_height * out_width)``.

    Raises:
        AssertionError: if the padded input minus the window is not an exact
            multiple of ``stride`` along either axis.
    """
    # First figure out what the size of the output should be
    _, C, H, W = x_shape
    pad_height, pad_width = padding

    span_height = H + int(np.sum(pad_height)) - field_height
    span_width = W + int(np.sum(pad_width)) - field_width

    assert span_height % stride == 0, 'padded height minus field_height must be divisible by stride'
    # The width check must use field_width (it previously used field_height).
    assert span_width % stride == 0, 'padded width minus field_width must be divisible by stride'

    # Integer division: these values are used as repeat/tile counts, which
    # NumPy requires to be integers.
    out_height = span_height // stride + 1
    out_width = span_width // stride + 1

    i0 = np.repeat(np.arange(field_height, dtype = 'int32'), field_width)
    i0 = np.tile(i0, C)
    i1 = stride * np.repeat(np.arange(out_height, dtype = 'int32'), out_width)

    j0 = np.tile(np.arange(field_width), field_height * C)
    j1 = stride * np.tile(np.arange(out_width, dtype = 'int32'), out_height)

    # Broadcast (offset within window) + (window origin) into 2-D index grids.
    i = i0.reshape(-1, 1) + i1.reshape(1, -1)
    j = j0.reshape(-1, 1) + j1.reshape(1, -1)

    k = np.repeat(np.arange(C, dtype = 'int32'), field_height * field_width).reshape(-1, 1)

    return (k, i, j)
def im2col_indices(x, field_height, field_width, padding, stride = 1):
    """
    Gather the sliding windows of ``x`` into a 2-D column matrix.

    Args:
        x: 4-D array; axis 1 is read as the channel count and the last two
            axes are the ones that get padded.
        field_height, field_width: window size along the last two axes.
        padding: ``(pad_height, pad_width)``, each passed to ``np.pad`` as the
            pad widths for one of the last two axes.
        stride: step between consecutive windows.

    Returns:
        Array with ``field_height * field_width * C`` rows, where ``C`` is
        ``x.shape[1]``; the remaining elements are flattened into the columns.
    """
    rows_pad, cols_pad = padding

    # The first two axes are left untouched; only the last two are padded.
    padded = np.pad(x, ((0, 0), (0, 0), rows_pad, cols_pad), mode = 'constant')

    chan_idx, row_idx, col_idx = get_im2col_indices(x.shape, field_height, field_width, padding, stride)
    patches = padded[:, chan_idx, row_idx, col_idx]

    n_rows = field_height * field_width * x.shape[1]
    return patches.transpose(1, 2, 0).reshape(n_rows, -1)
def col2im_indices(cols, x_shape, field_height = 3, field_width = 3, padding = ((0, 0), (0, 0)), stride = 1):
    """
    Scatter a column matrix back into an image-shaped array (inverse layout
    of im2col), based on fancy indexing and ``np.add.at``.

    Args:
        cols: array that can be reshaped to
            ``(C * field_height * field_width, -1, N)``.
        x_shape: 4-tuple unpacked as ``(N, C, H, W)``, the unpadded output shape.
        field_height, field_width: window size along each spatial axis.
        padding: ``(pad_height, pad_width)``; each entry is indexable, with
            element 0 being the leading (top / left) pad amount.
        stride: step between consecutive windows.

    Returns:
        Array of shape ``(N, C, H, W)`` with dtype ``cols.dtype``. It is a
        slice of the internal padded buffer with the padding cropped away.
    """
    N, C, H, W = x_shape
    pad_height, pad_width = padding
    H_padded, W_padded = H + np.sum(pad_height), W + np.sum(pad_width)
    x_padded = np.zeros((N, C, H_padded, W_padded), dtype = cols.dtype)

    k, i, j = get_im2col_indices(x_shape, field_height, field_width, padding, stride)

    cols_reshaped = cols.reshape(C * field_height * field_width, -1, N)
    cols_reshaped = cols_reshaped.transpose(2, 0, 1)

    # np.add.at accumulates contributions for indices that occur more than
    # once, which happens wherever windows overlap.
    np.add.at(x_padded, (slice(None), k, i, j), cols_reshaped)

    # Crop by the actual leading pad on each axis and the original H x W.
    # A single symmetric pad size derived from the height padding would give
    # the wrong shape for asymmetric padding such as (0, 1), or when height
    # and width padding differ.
    pad_top, pad_left = pad_height[0], pad_width[0]
    return x_padded[:, :, pad_top:pad_top + H, pad_left:pad_left + W]
pass