Получи случайную криптовалюту за регистрацию!

Хочу рассказать про Gumbel straight-through estimator, ибо я с | Матчасть

Хочу рассказать про Gumbel straight-through estimator, ибо я сам его наконец-то понял
Этот трюк пропускает градиент через дискретное представление, а что это такое, я сейчас объясню. Например, мы хотим обучить GAN: генератор порождает контент, дискриминатор оценивает его качество, и генератор обновляет свои веса в направлении улучшения качества. Это хорошо работает с картинками, потому что они непрерывные: генератор может на эпсилон подкрутить цвет каждого пикселя и картинка чуть улучшится. Математически это возможно потому, что мы можем взять производную сгенерированной картинки по параметрам генератора, и делать с её помощью градиентный шаг.

А как быть, если наш генератор генерирует тексты? Проблема текста, что он - тупо последовательность слов. У каждого слова есть какой-то номер в словаре, и фраза, например, "привет мир" может кодироваться как [14050, 5840, 1] (здесь единичка - это символ конца текста). И вот как это дифференцировать? Следите за руками!

1. Превращаем текст в sparse представление: матрицу размера (text_length, vocab_size), с единицами в позициях соответствующих слов, и нулями в остальных местах. Такая матрица выглядит уже чуть более дифференцировабельной. Но какая у неё может быть производная?
2. Вообще-то генератор текста сэмплирует его из предсказанного им распределения: softmax(logits), где logits - предсказания генератора. Их-то мы точно умеем дифференцировать, а вот оператор случайного выбора - недифференцируемый.
3. Оказывается, существует распределение Гумбеля, обладающее полезным свойством: распределения величин sample(softmax(logits)) и argmax(logits + gumbel_random()) - совпадают! В первом случае мы считаем вероятности каждого токена и случайно выбираем токен в соответствии с этими вероятностями. Во втором мы к скору каждого токена прибавляем гумбелевскую случайную величину, и выбираем токен с максимальным результатом. И эти процедуры - эквивалентны.
4. argmax всё ещё не дифференцируемая функция, но дифференцируемо её приближение: softmax. Получается, наша разреженная матрица hard_scores из пункта (1) примерно равна soft_scores=softmax(logits + gumbel_random()), и производную этой штуки мы уже умеем вычислять.
5. Как использовать матрицу hard_scores, прилепив к ней производную матрицы soft_scores? В pytorch это делается так: вычтем из разреженной матрицы приближенную, сбросим градиент этой разницы, и потом прибавим приближенную матрицу обратно. То есть выдадим (hard_scores - soft_scores).detach() + soft_scores.

Собственно, вот исходный код этой нечисти: https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax. Там ещё используется температура в софтмаксе: чем она меньше, тем ближе softmax к argmax. С подбором температуры для этого кейса я не экспериментировал, но если экспериментировали вы, то делитесь своими находками в комментах)