Source code for variable_data_loader

import random
import torch
from torch.utils.data import DataLoader, TensorDataset

class VariableDataLoader(object):
    """Load batches from variable length inputs.

    Samples are grouped by sequence length so that every yielded batch
    contains only input items of the same length.

    Attributes
    ----------
    lengths : dict()
        Dictionary of input-length -> tuple of tensors
        (inputs, labels, original indices) for samples of that length

    index : boolean, default=False
        If True, also returns original index

    batch_size : int, default=1
        Size of each batch to output

    shuffle : boolean, default=True
        If True, shuffle the data randomly

    keys : set()
        All input lengths present in the data

    done : set()
        Input lengths whose batches are exhausted in the current iteration

    data : dict()
        Dictionary of input-length -> iterator over the batches of that
        length
    """

    def __init__(self, X, y, index=False, batch_size=1, shuffle=True):
        """Load data from variable length inputs.

        Parameters
        ----------
        X : iterable of shape=(n_samples,)
            Input sequences
            Each item in iterable should be a sequence (of variable length)

        y : iterable of shape=(n_samples,)
            Labels corresponding to X

        index : boolean, default=False
            If True, also returns original index

        batch_size : int, default=1
            Size of each batch to output

        shuffle : boolean, default=True
            If True, shuffle the data randomly
        """
        # Group inputs, labels and original indices by input length.
        # Buckets are created only when a length is seen for the first time.
        groups = dict()
        for i, (X_, y_) in enumerate(zip(X, y)):
            length = len(X_)
            group = groups.get(length)
            if group is None:
                group = groups[length] = (list(), list(), list())
            group[0].append(X_)
            group[1].append(y_)
            group[2].append(i)

        # Transform each group to tensors (insertion order is preserved)
        self.lengths = {
            length: (
                torch.as_tensor(X_length),
                torch.as_tensor(y_length),
                torch.as_tensor(i_length),
            )
            for length, (X_length, y_length, i_length) in groups.items()
        }

        # Set configuration
        self.index = index
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Create the per-length batch iterators
        self.reset()

        # Get keys
        self.keys = set(self.data)

    def reset(self):
        """Reset the VariableDataLoader.

        Clears the set of exhausted lengths and creates a fresh batch
        iterator for every input length.
        """
        # Reset done
        self.done = set()

        # Reset DataLoaders
        self.data = {
            k: iter(DataLoader(
                TensorDataset(v[0], v[1], v[2]),
                batch_size=self.batch_size,
                shuffle=self.shuffle,
            ))
            for k, v in self.lengths.items()
        }

    def __iter__(self):
        """Returns iterable of VariableDataLoader.

        Returns
        -------
        self : VariableDataLoader
            The loader itself, reset to the start of the data
        """
        self.reset()
        return self

    def __next__(self):
        """Get next item of VariableDataLoader.

        Returns
        -------
        item : tuple
            (X, y) batch of same-length inputs and their labels, or
            (X, y, i) if index is True, where i holds the original indices

        Raises
        ------
        StopIteration
            When the batches of all input lengths have been yielded.
            The loader is reset before raising.
        """
        # Loop instead of recursing: a length is only discovered to be
        # exhausted when it is selected again, so with many distinct lengths
        # a recursive retry could exceed the interpreter's recursion limit.
        while self.done != self.keys:
            remaining = self.keys - self.done

            # Select key
            if self.shuffle:
                key = random.choice(list(remaining))
            else:
                key = min(remaining)

            # Only the call that can signal exhaustion is guarded
            try:
                X_, y_, i = next(self.data[key])
            except StopIteration:
                # Mark this length as finished and try another one
                self.done.add(key)
                continue

            if self.index:
                return (X_, y_, i)
            return (X_, y_)

        # All lengths are exhausted: prepare for the next pass and stop
        self.reset()
        raise StopIteration