Deep Learning

[Paper] VITRUVION: A GENERATIVE MODEL OF PARAMETRIC CAD SKETCHES (2)

ju_young 2023. 7. 3. 00:14
728x90

Constraint Model

  • constraint sequence의 autoregressive generation을 수행

  • $N_C$: constraint의 수
  • 각 constraint는 (type, parameters)의 형태를 가지는 tuple을 가짐
  • 2개 이하의 reference parameter를 가지는 모든 constraint를 다룸 (숫자 parameter 이거나 3개 이상의 reference를 가질 경우는 제외)

Ordering

constraint는 primitive에 따라 sorting된다.

# https://github.com/PrincetonLIPS/vitruvion/blob/1b91fff5597b3a4e272e6490e4d458f3ef790e62/img2cad/constraint_data.py#L161C5-L165C44
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
    seq = self.sequences[idx]

    sketch = datalib.sketch_from_sequence(seq)
    data_utils.normalize_sketch(sketch)

위 코드에서는 sketch_from_sequence와 primitive dataset에서처럼 normalizaion을 수행하는 부분 밖에 안보인다. 아무래도 sketgraph에서 구현한 기능을 그대로 사용한 것 같은데 왜 따로 Ordering이라는 부분을 작성해놓은건지 모르겠다. 심지어 similar to Seff et al. (2020) 이라고 sketchgraphs를 reference로 걸어놨다.

Noise injection

  • primitive model에서의 generation에 noise(가우스 노이즈)를 적용함으로써 constraint model에서 그 부분을 조정
# https://github.com/PrincetonLIPS/vitruvion/blob/1b91fff5597b3a4e272e6490e4d458f3ef790e62/img2cad/constraint_data.py#L167C9-L168C128
if self.primitive_noise_config.enabled:
    sketch = apply_primitive_noise(sketch, self.primitive_noise_config.std, self.primitive_noise_config.max_difference)
    
# https://github.com/PrincetonLIPS/vitruvion/blob/1b91fff5597b3a4e272e6490e4d458f3ef790e62/img2cad/constraint_data.py#L119-L125
def apply_primitive_noise(sketch: datalib.Sketch, std: float=0.15, max_difference: float=0.15) -> datalib.Sketch:
    noise_sketch = copy.deepcopy(sketch)
    try:
        noise_models.noisify_sketch_ents(noise_sketch, std=std, max_diff=max_difference)
    except:
        noise_sketch = sketch
    return noise_sketch

# https://github.com/PrincetonLIPS/vitruvion/blob/1b91fff5597b3a4e272e6490e4d458f3ef790e62/img2cad/noise_models.py#L188C1-L209
def noisify_sketch_ents(sketch: Sketch, std: float=0.2, max_diff: float=0.1):
    for ent_key, ent in sketch.entities.items():
        new_ent = _trunc_normal_entity_noise(ent, std, max_diff)
        new_ent.isConstruction = ent.isConstruction
        sketch.entities[ent_key] = new_ent
    return sketch

Tokenization

  • primitive model과 비슷하게 tokenization을 진행
# https://github.com/PrincetonLIPS/vitruvion/blob/1b91fff5597b3a4e272e6490e4d458f3ef790e62/img2cad/constraint_data.py#L173C7-L174
sample, gather_idx = dataset.tokenize_sketch(sketch, self.num_bins, self.max_length)
c_sample = tokenize_constraints(seq, gather_idx, self.max_length)

tokenize_sketch에서 얻은 gather_idx는 각 entity(arc, line, circle 등)의 reference parameter index를 의미한다. 그리고 이 gather_idx를 사용해서 각 constraint의 reference parameter를 추가한다. 이때 "parameter 값 + constraint type 수" 를 해준다.

최종적으로 (value, ID, position token)의 tuple 형태가 나오고 각각 parameter value, parameter type, 정렬된 constraint 위치를 가리킨다. 하지만 실제 코드에서는 val_tokens, coord_tokens, pos_tokens라는 변수를 사용하여 헷갈리게 구현해놨다. (살짝 짜증남)

# https://github.com/PrincetonLIPS/vitruvion/blob/1b91fff5597b3a4e272e6490e4d458f3ef790e62/img2cad/constraint_data.py#L101C9-L105C49
# Add reference parameters
val_tokens.extend(
    [gather_idxs[ref] + len(Token) for ref in sorted(refs)])
coord_tokens.extend(CONSTRAINT_COORD_TOKENS[:len(refs)])
pos_tokens.extend([pos_idx] * len(refs))

coord_tokens는 constraint의 reference의 수에 따라 유형이 나누어지는 것을 나타낸다. 만약 reference가 1개라면 [1]이 추가되고 2개면 [1, 2]가 추가된다.

Architecture

  • encoder-decoder transformer
# https://github.com/PrincetonLIPS/vitruvion/blob/1b91fff5597b3a4e272e6490e4d458f3ef790e62/img2cad/constraint_models.py#L71-L92
# Primitive model (for dynamic embeddings)
self.prim_model = modules.PrimitiveModel(num_bins, max_entities, embed_dim,
    fc_size, num_heads, num_layers, dropout, use_mask=False,
    linear_decode=False)
pad_tok = constraint_data.Token.Pad
# Value embeddings (only fixed ones)
num_val_embeddings = len(constraint_data.Token)
self.val_embed = torch.nn.Embedding(num_val_embeddings, embed_dim,
    padding_idx=pad_tok)
# Coordinate embeddings
num_coord_embeddings = 2 + len(
    constraint_data.CONSTRAINT_COORD_TOKENS)
self.coord_embed = torch.nn.Embedding(num_coord_embeddings, embed_dim,
    padding_idx=pad_tok)
# Position embeddings
num_pos_embeddings = 3 + (4 * max_entities)  # see make_sequence_dataset
self.pos_embed = torch.nn.Embedding(num_pos_embeddings, embed_dim,
    padding_idx=pad_tok)  # TODO: dry-ify overlapping logic w/ PrimModel
# Transformer decoder
decoder_layers = torch.nn.TransformerDecoderLayer(embed_dim, num_heads, fc_size,
    dropout)
self.trans_decoder = torch.nn.TransformerDecoder(decoder_layers, num_layers)

이전 포스트에서 primitive model과 constraint model을 나누었다는 의미를 위 코드에서 알 수 있다. primitive model을 encoder로 사용하고 decoder를 추가한 모습이다.

Embeddings

primitive model에 있는 TransformerModel에는 TransformerEncoder가 존재하고 이 encoder는 pirimitive token을 embedding한다. 이때는 linear layer와 softmax를 거치지 않는다.

primitive embedding은 위에서 언급한 val_tokens의 embedding과 concat된다.

그리고 최종적으로 val_embedding + coor_embedding + pos_embedding이 constraint embedding으로 사용한다.

# https://github.com/PrincetonLIPS/vitruvion/blob/1b91fff5597b3a4e272e6490e4d458f3ef790e62/img2cad/constraint_models.py#L94C5-L110C43
def _embed_tokens(self, src):
    # Embed primitives
    prim_embeddings = self.prim_model(src)
    # Prepend fixed val embeddings (constraint types)
    batch_size = src['c_val'].shape[0]
    fixed_tokens = np.tile(range(len(constraint_data.Token)), (batch_size, 1))
    fixed_tokens = torch.tensor(fixed_tokens).to(src['c_val'].device)
    fixed_val_embeddings = self.val_embed(fixed_tokens)
    prim_embeddings = torch.cat([fixed_val_embeddings, prim_embeddings], 1)
    # primitive embedding과 value embedding을 concat
    val_tokens = torch.unsqueeze(src['c_val'], 2).expand(
        -1, -1, prim_embeddings.shape[2])
    # Embed constraint tokens
    val_embeddings = torch.gather(prim_embeddings, 1, val_tokens)
    coord_embeddings = self.coord_embed(src['c_coord'])
    pos_embeddings = self.pos_embed(src['c_pos'])
    embeddings = val_embeddings + coord_embeddings + pos_embeddings
    return embeddings, prim_embeddings

Context Conditioning

Primer-conditional generation

primer는 그리다만 sketch를 표현하는 primitive sequence를 말한다. 그리고 이 sketch를 train된 primitive model에 input으로 들어가고 나머지 primitive는 stop token이 sampling될때까지 sampling한다.

# https://github.com/PrincetonLIPS/vitruvion/blob/1b91fff5597b3a4e272e6490e4d458f3ef790e62/img2cad/evaluation/sample_primitives_primed.py#L33
def complete_sketch(model: primitives_models.RawPrimitiveModule, sketch: datalib.Sketch) -> datalib.Sketch:
    """Completes the given sketch using the provided model.
    """
    num_position_bins = model.hparams.data.num_position_bins

    tok_input, _ = dataset.tokenize_sketch(sketch, num_position_bins, include_stop=False)
    tok_completed = sample_img2prim.sample(model.model, 130, tok_input=tok_input)

    completed_sketch = dataset.sketch_from_tokens(tok_completed, num_position_bins)
    return completed_sketch

Image-conditional generation

Primer-confitional generation에서 input에 이미지가 추가된 것으로 ViT와 비슷하게 먼저 이미지를 patch로 나누어 flatten한다. transformer encoder는 context embedding sequence를 학습하고 primitive decoder가 image context embedding으로 cross-attention을 수행한다.

코드에서는 image embedding을 위해 mobileNetV3를 사용하기도 하고 이미지를 Unfold를 사용하여 patch로 나눈 후 embedding하는 모델이 있다. 참고

Experiments

Image-confitional sample

맨 위의 손인지 발인지로 그린 sketch가 input이고 그 밑에 4개가 sampling된 결과이다. 몇 개는 sketch가 잘 나오고 몇 개는 이상하게 나오는 것을 확인할 수 있다. hand-drawn이 input으로 들어가는 경우는 큰 기대 안했다. (어차피 실용적인 면에서나 정밀도의 측면에서 사용하기에 극도로 꺼려진다.)

Primer-conditional generation

왼쪽에 작성되어있는 것처럼 Original이 원래 sketch이고 Primer가 일부분 지워버린 것이다. 오른쪽 각 6개의 sketch가 inference 결과인데 예측하기 어려운 경우가 무엇인지 확연하게 알기가 쉽다. 2, 4 번쨰의 경우 원래 형상을 사람이 봐도 예측할 수 없는 경우 모델도 예측하기 어려워한다. 다만 저러한 경우의 sketch도 여러 형상을 학습시킨다면 가능할지도 모른다. data hungry....

Constraining primitives

사실 여러 CAD 프로그램에서 auto constraint 기능은 대부분 가지고 있어서 관심은 없지만 결과는 위와 같다. 오른쪽 두 번째는 noise 없이 학습시킨 것이고 가장 오른쪽은 noise aumentation을 적용한 model을 사용했을 경우이다.

본 논문의 마지막에 "AutoCAD의 auto-constrain tool은 유용한 tool이지만 각 constraint type에 오차와 우선 순위를 직접 설정해야한다. 하지만 우리가 제안한 모델은 자동으로 adapt할 수 있다." 라고 하는데 이 부분은 완벽하게 동의할 수 없다. 자동으로 설정하고 constraint를 넣은들 잘못된 결과가 나온다면 사용자가 그 모델을 사용하지 않을 것이다. 또한 setting(설정)은 한 번 초기에 하고나서 바꿀일이 많이 없다. 여러 분야를 통달하는 대단한 분이 아니면 한 번 설정하고나서 건드릴 일이 많이 없다.

728x90