2021-01-19 08:09:09
В качестве корпуса для предобучения взяли C4 - прочищенный common crawl на 180 миллиардов токенов. Switch Transformer учили как Masked Language Model, маскируя 15% выходных токенов, по сути BERT-like трейн. Помимо кроссэнтропийной лосс-функции добавили еще один лосс, который форсит роутер равномерно раскидывать токены по имеющимся экспертам, чтобы не происходило того, что одни эксперты перегружены вычислениями, а другие недогружены.
Хаки для улучшения сходимости и качества:
- selective precision. Раньше большие MoE модели обучали float32, а с пониженной точностью обучение было нестабильным. Тут показали, что можно всю модель учить в bfloat16 и делать каст во float32 только для инпутов роутера. В результате все all-to-all операции по агрегации тензоров делаются в bfloat16, и мы сильно сокращаем косты на коммуникацию между девайсами.
- инициализация. Дефолтный трансформеры инициализируют из нормального распределения с mu = 0, sigma = sqrt(s/n), где n - кол-во входных юнитов, s - скейлинг фактор. Оказалось важным уменьшить s в 10 раз и инициализировать веса меньшими значениями, после этого сеть сходится существенно лучше.
- expert dropout. Для регуляризации, можно использовать стандартный dropout = 0.1 на всех слоях и будет ок. Но если на слоях с экспертами dropout учеличить до 0.4, то качество чуть бустанется.
Для сравнения эффективности Switch Transformera, в качестве бейзлайна выбрали T5 - большой dense трансформер от гугла, который выучили на том же корпусе C4.
Результаты:1. Чем больше добавляем экспертов, тем более sample efficient получается обучение + модель сходится к лучшим значениям. Также сравнили с классическим dense трансформером в виде T5. Показали что при одинаковом вычислительном бюджете (FLOPs per token) Switch transformer в 7 раз быстрее достигает качества, которое получается при обучении T5-base.
2. Switch Transformer хорошо файнтюнится на даунстрим NLP таски и бьет T5 бейзлайны почти везде - саммаризация, классификация, question answering, GLUE, SuperGLUE, etc.
3. Модель эффективно дистилируется. Взяли Switch-Base на 3.8B параметров и дистилировали в бейзлайн T5 на 223М параметров. Для эффективной дистиляции использовали два хака а) у модели-ученика все слои, за исключением экспертов, инициализировали весами из модели учителя б) при обучении дистилированно модели использовали микс из софт-лосса (для обучения ученика используем логиты от учителя), так и хард лосса(используем ground trouth лейблы). Дистилированная версия оказывается по качеству сильно лучше, чем такая же T5 выученная с нуля.
4. Улучшение на 101 языке. Показали, что есть сильное улучшение по перплексити как на high, так и на low resources языках. В качестве корпуса взяли мультиязычный C4 (mC4) и сравнили с мултиязычным T5 бейзлайном (mT5).
5. Можно учить огромные модели в триллионы параметров. Выучили Switch-C на 1.5T параметров, Switch-XXL на 395B, и бейзлайн T5-XXL на 13B. По качеству Switch-XXL оказался лучше чем Switch-C, несмотря на то, что последний в 4 раза больше. Это связано с тем, что хоть в Switch-XXL всего 64 эксперта (против 2048 в Switch-C), но каждый эксперт сильно жирнее + у всего трансформера больше слоев и аттеншн голов. Еще было замечено, что нет нестабильности в обучении Switch-C, но есть проблемы с Switch-XXL.
Если я не гугл, и у меня нет пода с 2 тысячами TPU, мне нахера все это?Если вы ресерчер с девбоксом в 4-8 GPU, все равно виден хороший буст по сравнению с аналогичным dense transformer'ом. Можно выучить такой Switch Transformer и дистилировать в более юзабельный dense transformer, получив буст в качестве.
1.1K viewsArtem R, 05:09