Dynamic Batching for Training Large Sequence Models (LLMs)
Preliminaries
To make the best use of GPU memory when training large models, we want to pack tokens into batches such that padding is minimised and GPU memory utilisation is maximised.
torch.utils.data.DataLoader is a Python iterable over a PyTorch dataset. torch.utils.data.Dataset implements __getitem__(), which maps keys to data samples. torch.utils.data.Sampler specifies the sequence of keys used in data loading.
By default, the DataLoader will collate individual fetched samples into batches using the arguments batch_size, drop_last, batch_sampler, and collate_fn. Alternatively, we can construct a BatchSampler which yields a list of keys at a time and pass it via the batch_sampler argument, leaving batch_size at its default.
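As a minimal sketch of how this plugs together (the toy dataset and batch sampler below are made up purely for illustration), any iterable that yields lists of keys can be handed to the DataLoader:

import torch
from torch.utils.data import DataLoader, Dataset

# Toy dataset: __getitem__ maps an integer key to a sample.
class ToyDataset(Dataset):
    def __init__(self, data):
        self.data = data
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]

# Any iterable that yields lists of keys can serve as a batch sampler.
def toy_batch_sampler(num_samples, batch_size):
    for start in range(0, num_samples, batch_size):
        yield list(range(start, min(start + batch_size, num_samples)))

dataset = ToyDataset([torch.randn(3) for _ in range(10)])
loader = DataLoader(dataset, batch_sampler=toy_batch_sampler(len(dataset), 4))
for batch in loader:
    print(batch.shape)  # the default collate_fn stacks each yielded list of samples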
Default Approach
We have several options, starting with the default.
The most straightforward option is to pad every sequence to the maximum context window and return a fixed batch size. However, this is incredibly wasteful. Imagine a batch size of 2, with a sequence X1 of length 10 and a sequence X2 of length 1000 in the same batch. X1 will be padded for 990 token positions, so nearly 50% of the batch is wasted GPU memory.
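A quick back-of-the-envelope check of the example above (the numbers are illustrative, and torch.nn.utils.rnn.pad_sequence stands in for whatever collator pads the batch):

import torch
from torch.nn.utils.rnn import pad_sequence

# Two sequences of length 10 and 1000 in the same batch.
seqs = [torch.ones(10, dtype=torch.long), torch.ones(1000, dtype=torch.long)]
batch = pad_sequence(seqs, batch_first=True, padding_value=0)  # shape (2, 1000)

real_tokens = sum(len(s) for s in seqs)
padding_fraction = 1 - real_tokens / batch.numel()
print(f"{padding_fraction:.1%} of the batch is padding")  # 49.5%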
Concatenate and Slice
The second approach is to concatenate all the data together and slice the long stream into smaller sequences, each with length equal to the maximum context window. This is more space efficient, but it may create spurious long-range dependencies, since many unrelated small sequences end up concatenated together.
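A rough sketch of this approach, assuming the data has already been tokenised into 1-D tensors of token ids (the helper name and shapes are just for illustration):

import torch

def concat_and_slice(sequences, max_seq_len):
    # Join all (possibly unrelated) sequences into one long token stream
    # and cut it into fixed-length windows of max_seq_len tokens.
    stream = torch.cat(sequences)
    n_chunks = len(stream) // max_seq_len  # drop the ragged tail
    return stream[:n_chunks * max_seq_len].view(n_chunks, max_seq_len)

seqs = [torch.randint(0, 100, (n,)) for n in (10, 1000, 37, 512)]
chunks = concat_and_slice(seqs, max_seq_len=256)
print(chunks.shape)  # torch.Size([6, 256]); tokens from different documents share a window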
Dynamic Batching Method
The idea behind dynamic batching is to maximise GPU memory utilisation, avoiding the wastefulness of the default approach, without introducing spurious relationships by concatenating unrelated sequences. Intuitively, we want short sequences to be grouped together with a larger batch size, and long sequences to be grouped together with a smaller batch size.
- Create 4 bins with ranges $(0, L/4)$, $(L/4, L/2)$, $(L/2, 3L/4)$, and $(3L/4, L)$, where $L$ is the maximum sequence length, i.e. the maximum context window.
def construct_bins(self):
    # Assign each example index to a bin based on its sequence length.
    print("Constructing bins", len(self.seq_lengths))
    for i, val in enumerate(tqdm(self.seq_lengths)):
        if val < (self.max_seq_len / 4):
            self.bins['small'].append(i)
        elif val < (self.max_seq_len / 2):
            self.bins['med'].append(i)
        elif val < ((self.max_seq_len / 4) * 3):
            self.bins['large'].append(i)
        else:
            self.bins['xl'].append(i)
- Sample from these bins in proportion to how much data each bin contains. To ensure we see all the data, we keep track of the current index within each bin; when the current index exceeds the bin size, we reset it to 0 and shuffle the bin. Note that a BatchSampler must yield a list of indices at a time, unlike a (non-batch) Sampler, which yields individual indices.
def __iter__(self):
    while True:
        # Batch size shrinks as sequence length grows: 2x, 1x, 2/3x and 1/2x
        # of the base (med) batch size for the small/med/large/xl bins.
        choices = [('small', self.med_batchsize * 2),
                   ('med', self.med_batchsize),
                   ('large', self.med_batchsize * 2 // 3),
                   ('xl', self.med_batchsize // 2)]
        # Pick a bin with probability proportional to how many examples it holds.
        weights = [len(self.bins['small']), len(self.bins['med']),
                   len(self.bins['large']), len(self.bins['xl'])]
        bin_type, size = random.choices(choices, weights, k=1)[0]
        print("SAMPLER ITER CALLED", bin_type, size)
        cur_index = self.current_index[bin_type]
        self.current_index[bin_type] += size
        if self.current_index[bin_type] > len(self.bins[bin_type]):
            # Wrapped around: reset the pointer and reshuffle the bin.
            self.current_index[bin_type] = 0
            random.shuffle(self.bins[bin_type])
        yield self.bins[bin_type][cur_index:cur_index + size]
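With the defaults used here (a base batch size of 64 and $L = 1024$), the per-bin batch sizes come out to roughly 128/64/42/32, which keeps the worst-case number of tokens in a batch roughly constant. A quick sanity check of those numbers:

max_seq_len, med = 1024, 64
batch_sizes = {'small': med * 2, 'med': med, 'large': med * 2 // 3, 'xl': med // 2}
upper_len = {'small': max_seq_len // 4, 'med': max_seq_len // 2,
             'large': max_seq_len * 3 // 4, 'xl': max_seq_len}
for name, bs in batch_sizes.items():
    # Worst case: every sequence in the batch sits at the bin's upper length bound.
    print(name, bs, bs * upper_len[name])
# small 128 32768, med 64 32768, large 42 32256, xl 32 32768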
We want to interface with the HuggingFace Trainer and the PyTorch DataLoader, and therefore encapsulate our methods inside a custom class that inherits from BatchSampler.
The complete class looks like this:
import random
from torch.utils.data import BatchSampler
from tqdm import tqdm

class DynamicBatchSampler(BatchSampler):
    def __init__(self, seq_lengths, max_seq_len=1024, med_batchsize=64):
        self.med_batchsize = med_batchsize
        self.seq_lengths = seq_lengths  # tokenised length of each training example
        self.bins = {'small': [], 'med': [], 'large': [], 'xl': []}
        self.max_seq_len = max_seq_len
        self.construct_bins()
        self.current_index = {'small': 0, 'med': 0, 'large': 0, 'xl': 0}

    def __iter__(self):
        ...

    def construct_bins(self):
        ...

    def __len__(self):
        # Return a rough estimate of the number of batches per epoch.
        return sum([
            len(self.bins['small']) // (self.med_batchsize * 2 + 1),
            len(self.bins['med']) // (self.med_batchsize + 1),
            len(self.bins['large']) // (self.med_batchsize * 2 // 3 + 1),
            len(self.bins['xl']) // (self.med_batchsize // 2 + 1)
        ]) // 4
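As a quick sanity check, the sampler can be driven on its own before wiring it into a DataLoader (the lengths below are purely synthetic, just to exercise the bins):

import random

seq_lengths = [random.randint(1, 1024) for _ in range(10_000)]
sampler = DynamicBatchSampler(seq_lengths, max_seq_len=1024, med_batchsize=64)

batches = iter(sampler)
for _ in range(3):
    indices = next(batches)  # a list of dataset indices drawn from one bin
    print(len(indices), max(seq_lengths[i] for i in indices))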
HuggingFace Trainer Compatibility
To make this compatible with the HuggingFace Trainer, we subclass Trainer and override the _get_train_sampler() method to return our newly constructed DynamicBatchSampler.
class MyTrainer(Trainer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _get_train_sampler(self):
        # The training dataset is assumed to carry a precomputed 'lengths' column.
        return DynamicBatchSampler(self.train_dataset['lengths'],
                                   max_seq_len=1024, med_batchsize=64)

    def get_train_dataloader(self):
        # Override this method to take full control of the dataloader arguments.
        ...
        data_collator = self.data_collator  # Trainer stores the collator on self
        dataloader_params = {"collate_fn": data_collator,
                             "num_workers": self.args.dataloader_num_workers,
                             "pin_memory": self.args.dataloader_pin_memory,
                             "shuffle": False,
                             "sampler": None,
                             "batch_sampler": self._get_train_sampler(),
                             "drop_last": False,
                             "persistent_workers": self.args.dataloader_persistent_workers}
        return self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
Then we can simply do trainer = MyTrainer(...); trainer.train().
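A minimal sketch of the end-to-end wiring, assuming model, train_dataset and data_collator already exist and that train_dataset carries a precomputed 'lengths' column:

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    max_steps=1000,                 # the sampler yields batches indefinitely, so cap training by steps
    per_device_train_batch_size=1,  # not used for batching here: the BatchSampler decides batch sizes
)
trainer = MyTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    data_collator=data_collator,
)
trainer.train()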
Implementation Note:
It’s necessary to override the get_train_dataloader() method to have full control over dataloader_params. In particular, because we constructed a custom BatchSampler, the batch_size argument must not be passed to the DataLoader, or an error will be thrown about batch_size conflicting with batch_sampler. batch_size defaults to 1 in PyTorch, but the HuggingFace Trainer passes per_device_train_batch_size, which defaults to 8.
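For reference, this is the kind of conflict the override avoids (assuming a dataset and a DynamicBatchSampler instance like the ones constructed above):

from torch.utils.data import DataLoader

# Passing both batch_size and batch_sampler raises a ValueError, since PyTorch
# treats batch_sampler as mutually exclusive with batch_size, shuffle, sampler
# and drop_last.
loader = DataLoader(train_dataset, batch_size=8, batch_sampler=dynamic_sampler)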