Figure 1. Without bucketing
Figure 2. With bucketing
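To make the figures concrete, here is a minimal sketch of the padding cost they illustrate. The helper count_pad_tokens is a hypothetical name introduced only for this illustration; it counts the <PAD> tokens needed when every sentence in a batch is padded to the batch maximum, using the lengths from the docstring example below.

# Illustration only: count_pad_tokens is a hypothetical helper,
# not part of the assignment code.
def count_pad_tokens(lengths):
    longest = max(lengths)
    return sum(longest - length for length in lengths)

lengths = [7, 4, 9, 2, 5, 10]
print(count_pad_tokens(lengths))          # one unbucketed batch: 23 <PAD> tokens
print(count_pad_tokens([7, 9, 10])
      + count_pad_tokens([4, 2, 5]))      # bucketed as in the example: 8 <PAD> tokens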
import random
from collections import defaultdict
from typing import List, Tuple


def bucketed_batch_indices(
    sentence_length: List[Tuple[int, int]],
    batch_size: int,
    max_pad_len: int
) -> List[List[int]]:
""" Function for bucketed batch indices
Although the loss calculation does not consider PAD tokens,
it actually takes up GPU resources and degrades performance.
Therefore, the number of <PAD> tokens in a batch should be minimized in order to maximize GPU utilization.
Implement a function which groups samples into batches that satisfy the number of needed <PAD> tokens in each sentence is less than or equals to max_pad_len.
Note 1: several small batches which have less samples than batch_size are okay but should not be many. If you pass the test, it means "okay".
Note 2: you can directly apply this function to torch.utils.data.dataloader.DataLoader with batch_sampler argument.
Read the test codes if you are interested in.
Arguments:
sentence_length -- list of (length of source_sentence, length of target_sentence) pairs.
batch_size -- batch size
max_pad_len -- maximum padding length. The number of needed <PAD> tokens in each sentence should not exceed this number.
return:
batch_indices_list -- list of indices to be a batch. Each element should contain indices of sentence_length list.
Example:
If sentence_length = [7, 4, 9, 2, 5, 10], batch_size = 3, and max_pad_len = 3,
then one of the possible batch_indices_list is [[0, 2, 5], [1, 3, 4]]
because [0, 2, 5] indices has simialr length as sentence_length[0] = 7, sentence_length[2] = 9, and sentence_length[5] = 10.
"""
    batch_map = defaultdict(list)
    batch_indices_list = []
    src_len_min = min(sentence_length, key=lambda t: t[0])[0]  # min source length (first element)
    tgt_len_min = min(sentence_length, key=lambda t: t[1])[1]  # min target length (second element)
    for idx, (src_len, tgt_len) in enumerate(sentence_length):
        # Bucket index: group lengths into windows of max_pad_len + 1 so that
        # padding within a bucket never exceeds max_pad_len
        # (the figures use a window of 5).
        src = (src_len - src_len_min) // (max_pad_len + 1)
        tgt = (tgt_len - tgt_len_min) // (max_pad_len + 1)
        batch_map[(src, tgt)].append(idx)
    for key, value in batch_map.items():
        batch_indices_list += [value[i: i + batch_size] for i in range(0, len(value), batch_size)]
    # Don't forget to shuffle the batches, since the length distribution
    # across batches could otherwise be biased.
    random.shuffle(batch_indices_list)
    return batch_indices_list
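Following Note 2 in the docstring, here is a usage sketch with DataLoader's batch_sampler argument. ToyPairDataset and collate_fn are illustrative assumptions introduced for this sketch, not part of the original; only the batch_sampler wiring comes from the docstring.

import torch
from torch.utils.data import DataLoader, Dataset

class ToyPairDataset(Dataset):
    # Hypothetical stand-in for the real parallel-corpus dataset.
    def __init__(self, lengths):
        self.data = [(torch.ones(s, dtype=torch.long),
                      torch.ones(t, dtype=torch.long)) for s, t in lengths]
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch):
    # Pad source and target sides separately to each batch's maximum length.
    srcs, tgts = zip(*batch)
    pad = torch.nn.utils.rnn.pad_sequence
    return pad(srcs, batch_first=True), pad(tgts, batch_first=True)

sentence_length = [(7, 6), (4, 3), (9, 8), (2, 2), (5, 4), (10, 9)]
loader = DataLoader(
    ToyPairDataset(sentence_length),
    batch_sampler=bucketed_batch_indices(sentence_length, batch_size=3, max_pad_len=3),
    collate_fn=collate_fn,
)
for src_batch, tgt_batch in loader:
    print(src_batch.shape, tgt_batch.shape)

Because each batch groups sentences of similar length, the padded tensors printed above stay close to their shortest member, which is exactly the GPU-utilization gain the docstring describes.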