Получи случайную криптовалюту за регистрацию!

Очень часто модели компьютерного зрения при обработке картинок | DLStories | Нейронные сети и ИИ

Очень часто модели компьютерного зрения при обработке картинок обращают достаточно много внимания на фон, а не на сами объекты. Из-за этого модели получаются не очень устойчивы (robust), т.е. при малейших изменениях в распределениях картинок (domain shift) начинают хуже работать. Такому подвержены даже самые современные модели вроде vision transformers (ViT). Например, на первой картинке к посту показано, как модель ViT классифицирует картинку лимона как 'golf ball', потому что фон картинки — трава.

Статья "Optimizing Relevance Maps of Vision Transformers Improves Robustness" предлагает метод для решения этой проблемы. Он довольно простой и заключается вот в чем:

Берем предобученный ViT. Также берем по 3 изображения из 500 классов ImageNet (всего 1500 картинок), у которых есть сегментационная разметка вида объект-фон. Теперь дообучаем ViT 50 эпох на этих 1500 картиках, используя обычный классификационный лосс + дополнительный relevance лосс. Relevance loss заставляет модель фокусироваться преимущественно на пикселях объекта на картинке, и как можно меньше внимания обращать на пиксели фона.

Вычисляется relevance лосс очень просто (cм. 2 фото к посту). Здесь S(i) — segmentation map изображения i, S(i) с чертой — inverse segmentation map, R(i) — rclass elevance map. Class relevance map — это карта, отражающая релевантность пикселей картинки классам задачи. Идея class relevance map позаимствована из статьи "Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers", в которой авторы предлагают новый метод для интерпретации работы (explainability) transformer-like моделей.

Эта идея действительно улучшает robustness различных моделей ViT, в среднем увеличивая top-1 accuracy на 5% на датасетах INet-A, ObjectNet. На третьей картинке к посту показаны того, как идея статьи помогает исправить предсказания модели.
При этом идея довольно простая в исполнении: нужно всего по паре-тройке сегментированных картинок из каждого класса и чуть времени на дообучение. Более того, карты сегментации для стадии дообучения можно получать в unsupervised режиме, с помощью метода для object localization, предложенного в этой статье. В таком unsupervised режиме дообучение также улучшает accuracy.

Больше результатов на разных датасетах в статье на arxiv.
Также к статье есть отлично структурированный код на GitHub с демо на колабе