ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • Object-Centric Learning with Slot Attention
    논문 정리 2023. 3. 25. 18:29
    반응형

    이 논문에서는 CNN을 사용하여 복잡한 장면의 객체 중심 표현(object-centric representation)을 학습하는 새로운 방법을 소개한다. 이 방법 슬롯이라고 하는 task-dependent abstract representation을 생성하는 slot attention module을 사용한다. 이러한 슬롯은 반복적인 attention 과정을 통해 업데이트가 가능하며 모든 입력 feature와 상호작용한다. 이 논문은 slot attention이 unsupervised object discovery 및 supervised property prediction task에 대해 학습할 때 보이지 않는 구성에 대한 일반화를 가능하게 하는 객체 중심 표현을 추출할 수 있음을 보여준다. 또한 이 논문에서는 이 방법을 객체 표현을 위한 기존의 딥러닝 접근 방식과 비교하고 향후 연구 방향을 제시한다.

     

    1. Object-centric learning

    객체 중심 학습은 장면이나 환경에 있는 객체의 표현을 학습하는 데 중점을 두는 딥러닝의 한 유형이다. 이 task의 목표는 장면에서 개별 개체를 식별하고 표현함으로써 low-level perceptual feature으로부터 abstract reasoning을 가능하게 하는 것이다. 이 접근 방식은 시각적 추론, 구조화된 환경 모델링, 다중 에이전트 모델링, 상호 작용하는 물리적 시스템 시뮬레이션 등 다양한 애플리케이션 영역에서 머신러닝 알고리즘의 샘플 효율성과 일반화를 개선할 수 있는 잠재력을 가지고 있다. 그러나 이미지나 비디오와 같은 raw perceptual 입력에서 객체 중심의 표현을 얻는 것은 어렵고 supervision 학습 또는 task-specific architecture가 필요한 경우가 많다.

     

    2. Slot attention

     

    슬롯 어텐선의 구현은 아래와 같다. (https://github.com/lucidrains/slot-attention)

    import torch
    from torch import nn
    from torch.nn import init
    
    class SlotAttention(nn.Module):
        def __init__(self, num_slots, dim, iters = 3, eps = 1e-8, hidden_dim = 128):
            super().__init__()
            self.num_slots = num_slots
            self.iters = iters
            self.eps = eps
            self.scale = dim ** -0.5
    
            self.slots_mu = nn.Parameter(torch.randn(1, 1, dim))
    
            self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, dim))
            init.xavier_uniform_(self.slots_logsigma)
    
            self.to_q = nn.Linear(dim, dim)
            self.to_k = nn.Linear(dim, dim)
            self.to_v = nn.Linear(dim, dim)
    
            self.gru = nn.GRUCell(dim, dim)
    
            hidden_dim = max(dim, hidden_dim)
    
            self.mlp = nn.Sequential(
                nn.Linear(dim, hidden_dim),
                nn.ReLU(inplace = True),
                nn.Linear(hidden_dim, dim)
            )
    
            self.norm_input  = nn.LayerNorm(dim)
            self.norm_slots  = nn.LayerNorm(dim)
            self.norm_pre_ff = nn.LayerNorm(dim)
    
        def forward(self, inputs, num_slots = None):
            b, n, d, device, dtype = *inputs.shape, inputs.device, inputs.dtype
            n_s = num_slots if num_slots is not None else self.num_slots
            
            mu = self.slots_mu.expand(b, n_s, -1)
            sigma = self.slots_logsigma.exp().expand(b, n_s, -1)
    
            slots = mu + sigma * torch.randn(mu.shape, device = device, dtype = dtype)
    
            inputs = self.norm_input(inputs)        
            k, v = self.to_k(inputs), self.to_v(inputs)
    
            for _ in range(self.iters):
                slots_prev = slots
    
                slots = self.norm_slots(slots)
                q = self.to_q(slots)
    
                dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
                attn = dots.softmax(dim=1) + self.eps
    
                attn = attn / attn.sum(dim=-1, keepdim=True)
    
                updates = torch.einsum('bjd,bij->bid', v, attn)
    
                slots = self.gru(
                    updates.reshape(-1, d),
                    slots_prev.reshape(-1, d)
                )
    
                slots = slots.reshape(b, -1, d)
                slots = slots + self.mlp(self.norm_pre_ff(slots))
    
            return slots

     

    기본적으로 "query", "key", "value" 기반의 어텐션 메커니즘을 사용한다. 우선 입력 feature를 key, value로 만들고 랜덤 초기화된 슬롯들을 query로 사용한다. 슬롯 어텐션은 "iters" 만큼 반복되면서 (논문의 경우 3) 슬롯을 점진적으로 업데이트 한다. 입력 및 출력의 크기는 아래 코드를 참고하면 된다.

     

    import torch
    from slot_attention import SlotAttention
    
    slot_attn = SlotAttention(
        num_slots = 5,
        dim = 512,
        iters = 3   # iterations of attention, defaults to 3
    )
    
    inputs = torch.randn(2, 1024, 512)
    slot_attn(inputs) # (2, 5, 512)

     

    논문은 object-centric learning 기반의 세그멘테이션 작업을 수행한다. 결과는 아래와 같다.

    각각의 슬롯에는 개별의 물체에 대한 정보가 저장되며 이 과정은 모두 unsupervised로 이루어진다. 이 세그멘테이션 과정이 슬롯으로 부터 어떻게 이루어 지는지는 아래 코드를 보면 쉽게 알수있다. (https://github.com/evelinehong/slot-attention-pytorch)

     

    # Slot Attention module.
    slots = self.slot_attention(x)
    # `slots` has shape: [batch_size, num_slots, slot_size].
    
    # """Broadcast slot features to a 2D grid and collapse slot dimension.""".
    slots = slots.reshape((-1, slots.shape[-1])).unsqueeze(1).unsqueeze(2)
    slots = slots.repeat((1, 8, 8, 1))
    
    # `slots` has shape: [batch_size*num_slots, width_init, height_init, slot_size].
    x = self.decoder_cnn(slots)
    # `x` has shape: [batch_size*num_slots, width, height, num_channels+1].
    
    # Undo combination of slot and batch dimension; split alpha masks.
    recons, masks = x.reshape(image.shape[0], -1, x.shape[1], x.shape[2], x.shape[3]).split([3,1], dim=-1)
    # `recons` has shape: [batch_size, num_slots, width, height, num_channels].
    # `masks` has shape: [batch_size, num_slots, width, height, 1].
    
    # Normalize alpha masks over slots.
    masks = nn.Softmax(dim=1)(masks)
    recon_combined = torch.sum(recons * masks, dim=1)  # Recombine image.
    recon_combined = recon_combined.permute(0,3,1,2)
    # `recon_combined` has shape: [batch_size, width, height, num_channels].

     

    처음 모델 그림에서 슬롯 개수 만큼의 여러개의 디코더를 사용하는 것 처럼 표현되어있지만 실은 batch_size x num_slots로 슬롯을 배치 단위로 묶어 연산한다. 모델은 오토 인코더와 유사하게 구성되며 이렇게 하면 3채널의 recons 이미지와 각각에 대한 마스크가 알파 채널로써 생성되는데, recon_combined = torch.sum(recons * masks, dim=1)으로 원래 이미지를 재구성하여 입력 이미지와 loss를 계산한다. 결과적으로 unsupervised로 masks 학습이 가능해진다.

    반응형

    댓글

Designed by black7375.