Required Assignment 4

Notes only on the parts I found confusing

2. Bucketing

Figure 1. Without bucketing


Figure 2. With bucketing

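A small sketch can make the two figures concrete. It counts how many <PAD> tokens a set of batches needs, assuming every sentence is padded to the longest sentence in its batch. It compares batches formed in dataset order with batches formed after grouping similar lengths. The lengths are made up, and the sort-then-chunk grouping is only a simplified stand-in for bucketing, not the assignment's method. The helpers `pad_tokens_needed` and `chunk` are hypothetical.

```python
from typing import List


def pad_tokens_needed(lengths: List[int], batches: List[List[int]]) -> int:
    """Total <PAD> tokens when every sentence is padded to its batch's longest sentence."""
    total = 0
    for batch in batches:
        longest = max(lengths[i] for i in batch)
        total += sum(longest - lengths[i] for i in batch)
    return total


def chunk(indices: List[int], batch_size: int) -> List[List[int]]:
    """Split a list of indices into consecutive batches of at most batch_size."""
    return [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]


# Hypothetical sentence lengths, listed in dataset order.
lengths = [3, 12, 4, 11, 5, 13]
batch_size = 3

# Without bucketing: batches simply follow the dataset order.
naive = chunk(list(range(len(lengths))), batch_size)
# Simplified grouping: sort indices by length, then chunk.
grouped = chunk(sorted(range(len(lengths)), key=lambda i: lengths[i]), batch_size)

print(naive, pad_tokens_needed(lengths, naive))      # [[0, 1, 2], [3, 4, 5]] 27
print(grouped, pad_tokens_needed(lengths, grouped))  # [[0, 2, 4], [3, 1, 5]] 6
```

Mixing short and long sentences in one batch forces the short ones to carry many <PAD> tokens. Grouping similar lengths removes most of that waste.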

import random
from collections import defaultdict
from typing import List, Tuple

def bucketed_batch_indices(
    sentence_length: List[Tuple[int, int]],
    batch_size: int,
    max_pad_len: int
) -> List[List[int]]:
    """ Function for bucketed batch indices
    Although the loss calculation does not consider PAD tokens,
    it actually takes up GPU resources and degrades performance.
    Therefore, the number of <PAD> tokens in a batch should be minimized in order to maximize GPU utilization.
    Implement a function that groups samples into batches such that the number of <PAD> tokens needed for each sentence is less than or equal to max_pad_len.
    
    Note 1: a few small batches with fewer samples than batch_size are okay, but there should not be many of them. If you pass the test, it means "okay".

    Note 2: you can pass the result of this function directly to torch.utils.data.dataloader.DataLoader through the batch_sampler argument.
    Read the test code if you are interested.

    Arguments:
    sentence_length -- list of (length of source_sentence, length of target_sentence) pairs.
    batch_size -- batch size
    max_pad_len -- maximum padding length. The number of needed <PAD> tokens in each sentence should not exceed this number.

    return:
    batch_indices_list -- list of indices to be a batch. Each element should contain indices of sentence_length list.

    Example:
    If sentence_length = [7, 4, 9, 2, 5, 10] (one length per sentence, for simplicity), batch_size = 3, and max_pad_len = 3,
    then one of the possible batch_indices_list is [[0, 2, 5], [1, 3, 4]],
    because the sentences at indices [0, 2, 5] have similar lengths: sentence_length[0] = 7, sentence_length[2] = 9, and sentence_length[5] = 10.
    """
    batch_map = defaultdict(list)
    batch_indices_list = []
    src_len_min = min(sentence_length, key=lambda t: t[0])[0]  # minimum source length (index 0 of each pair)
    tgt_len_min = min(sentence_length, key=lambda t: t[1])[1]  # minimum target length (index 1 of each pair)
    for idx, (src_len, tgt_len) in enumerate(sentence_length):
        # Bucket id = quotient by max_pad_len (5 in the figures). Each bucket spans max_pad_len
        # consecutive lengths, so a sentence needs at most max_pad_len - 1 <PAD> tokens inside it.
        src = (src_len - src_len_min + 1) // max_pad_len
        tgt = (tgt_len - tgt_len_min + 1) // max_pad_len
        batch_map[(src, tgt)].append(idx)
    
    # Split each bucket into batches of at most batch_size indices
    for value in batch_map.values():
        batch_indices_list += [value[i: i+batch_size] for i in range(0, len(value), batch_size)]
    ### End of implementation

    # Don't forget to shuffle the batches; otherwise batches from the same bucket come in a row and their lengths are biased
    random.shuffle(batch_indices_list)

    return batch_indices_list
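One way to check the docstring's constraint is to recompute, for every returned batch, the largest gap between the longest and shortest sentence on the source and target side. That gap is the most <PAD> tokens any sentence in the batch needs. The sketch below restates the bucketing logic above so it runs on its own. It uses `max_pad_len` as the bucket width. `max_padding_in_batches` is a hypothetical helper, not part of the assignment's test code, and the random data is synthetic. Torch is left out so the snippet is self-contained. The returned list of index lists is the shape that `DataLoader`'s `batch_sampler` argument accepts.

```python
import random
from collections import defaultdict
from typing import Dict, List, Tuple


def bucketed_batch_indices(sentence_length: List[Tuple[int, int]],
                           batch_size: int, max_pad_len: int) -> List[List[int]]:
    src_min = min(s for s, _ in sentence_length)
    tgt_min = min(t for _, t in sentence_length)
    buckets: Dict[Tuple[int, int], List[int]] = defaultdict(list)
    for idx, (src_len, tgt_len) in enumerate(sentence_length):
        # Each bucket covers max_pad_len consecutive lengths per side,
        # so at most max_pad_len - 1 <PAD> tokens are needed inside it.
        key = ((src_len - src_min + 1) // max_pad_len,
               (tgt_len - tgt_min + 1) // max_pad_len)
        buckets[key].append(idx)
    batches = [v[i:i + batch_size] for v in buckets.values()
               for i in range(0, len(v), batch_size)]
    random.shuffle(batches)
    return batches


def max_padding_in_batches(sentence_length: List[Tuple[int, int]],
                           batches: List[List[int]]) -> int:
    """Largest number of <PAD> tokens any single sentence needs, over both sides."""
    worst = 0
    for batch in batches:
        for side in (0, 1):  # 0 = source, 1 = target
            lens = [sentence_length[i][side] for i in batch]
            worst = max(worst, max(lens) - min(lens))
    return worst


random.seed(0)
data = [(random.randint(5, 40), random.randint(5, 40)) for _ in range(1000)]
batches = bucketed_batch_indices(data, batch_size=64, max_pad_len=5)
print(max_padding_in_batches(data, batches) <= 5)  # padding constraint holds
print(sorted(i for b in batches for i in b) == list(range(len(data))))  # every index used exactly once
```

Bucketing on source and target lengths together keeps padding low on both sides. It also means many small buckets when lengths are spread out, which is why Note 1 tolerates some undersized batches.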

Master Class - Prof. 주재걸 (Jaegul Choo)