Introduction
이전에 CNN, Faster R-CNN, Mask R-CNN 모델을 사용하여 detection하고 recognition하는 것과 end-to-end 구조로 CNN은 semantic structure를 추출하고 NLP 모델은 text embedding의 이점을 가져오는 multimodal 모델이 있었다. 또한 비지니스 문서에서 정보를 추출하는 GCN(Graph Convolutional Network)도 알려져있다. 하지만 이러한 network들은 다음과 같은 한계를 가진다.
- 사람이 labling한 training sample에 의존한다.
- pre-trained CV model과 NLP model은 보통 좋은 성과를 가져다 주지만 textual, layout 정보를 학습하는 것은 고려되지 않았다.
LayoutLM은 BERT와 같이 text embedding과 position embedding을 통해 textual 정보를 표현한다. 그리고 input embedding은 두 가지 유형이 있다.
- 문서 내에 있는 token의 relative position을 나타내는 2-D position embedding
- 문서 내에 있는 스캔된 token 이미지의 image embedding
이렇게 두 가지 input embedding을 사용하는 이유는 2-D position embedding 같은 경우 문서 내에 있는 token들 사이의 관계를 잡아내고, 반면에 image embedding은 font, type, color와 같은 feature들을 잡아내기 때문이다.
또한 Masked Visual-Language Model (MVLM) loss와 Multi-label Document Classification (MDC) loss를 사용하여 text와 layout을 pre-train을 같이 하는 multi-task learning을 한다.
본 논문에서는 스캔된 문서 이미지를 pre-train하는 것에 집중하였다.
The LayoutLM Model
- Document Layout Information
문서 내의 word들의 relative position은 많은 의미적인 표현을 가지고 있다. 예를 들어서 여권에 "Passport ID:"와 같은 key가 주어진다면 이에 대응되는 value는 왼쪽이나 위에 있다기보다 아래나 오른쪽에 있을 가능성이 훨씬 높다. 그래서 relative position 정보를 2-D position representation으로 embedding한다.
- Visual Information
문서 단위에서 visual feature는 문서의 layout을 가리키고 문서 이미지 분류를 위해 중요한 feature가 된다. 그리고 word 단위에서 visual feature는 bold, underline, italic과 같은 style들이 sequence labeling task를 위한 중요한 hint가 된다. 그렇기 때문에 전통적인 text representation과 함께 image feature를 결합하는 것이 문서에서의 의미적인 표현을 더 많이 가져다 줄 수 있을 것이라 믿는다고 한다.
Model Architecture
- 2-D Position Embedding
2-D position embedding은 문서 내의 relative spatial position에 초점을 맞추었다. 이 모델은 스캔된 문서 이미지에있는 element들의 spatial position을 표현하기위해 문서 페이지에서 좌상(top-left) 점을 원점으로하여 좌표를 지정하였다. 이렇게해서 bounding box가 $(x_0, y_0, x_1, y_1)$으로 정밀하게 정의될 수 있다. 여기서 $(x_0, y_0)$는 bouning box의 좌상점이고 $(x_1, y_1)$는 bounding box의 우하점이다.
따라서 이 모델에서는 네 개의 position embedding과 두 개의 embedding table을 추가했다. 이것은 embedding table X에서$x_0, x_1$의 position embedding을 찾고 embedding table Y에서 $y_0, y_1$을 찾는다는 의미이다.
class LayoutLMEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config):
super(LayoutLMEmbeddings, self).__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
self.LayerNorm = LayoutLMLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
def forward(
self,
input_ids=None,
bbox=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
device = input_ids.device if input_ids is not None else inputs_embeds.device
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
words_embeddings = inputs_embeds
position_embeddings = self.position_embeddings(position_ids)
try:
left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
except IndexError as e:
raise IndexError("The `bbox`coordinate values should be within 0-1000 range.") from e
h_position_embeddings = self.h_position_embeddings(bbox[:, :, 3] - bbox[:, :, 1])
w_position_embeddings = self.w_position_embeddings(bbox[:, :, 2] - bbox[:, :, 0])
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = (
words_embeddings
+ position_embeddings
+ left_position_embeddings
+ upper_position_embeddings
+ right_position_embeddings
+ lower_position_embeddings
+ h_position_embeddings
+ w_position_embeddings
+ token_type_embeddings
)
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
- Image Embedding
OCR 결과로 얻은 bouding box의 각 word와 여러 조각으로 이미지를 분할하고 그것들을 하나씩 대응시킨다. Faster R-CNN으로 얻은 이러한 이미지의 조각들을 token image embedding으로 사용하여 image region feature를 만든다. [CLS] 토큰의 경우 전체 스캔 문서 이미지를 RoI로 사용하여 embedding을 생성하기위해 Faster R-CNN을 사용한다고 한다. (huggingface에서 image embedding이 어떻게 구현되어있는지 찾아보려했지만 안보인다....)