Dive into pyTorch |
||
МЕНЮ Искусственный интеллект Поиск Регистрация на сайте Помощь проекту ТЕМЫ Новости ИИ Искусственный интеллект Разработка ИИГолосовой помощник Городские сумасшедшие ИИ в медицине ИИ проекты Искусственные нейросети Слежка за людьми Угроза ИИ ИИ теория Внедрение ИИКомпьютерные науки Машинное обуч. (Ошибки) Машинное обучение Машинный перевод Реализация ИИ Реализация нейросетей Создание беспилотных авто Трезво про ИИ Философия ИИ Big data Работа разума и сознаниеМодель мозгаРобототехника, БПЛАТрансгуманизмОбработка текстаТеория эволюцииДополненная реальностьЖелезоКиберугрозыНаучный мирИТ индустрияРазработка ПОТеория информацииМатематикаЦифровая экономика
Генетические алгоритмы Капсульные нейросети Основы нейронных сетей Распознавание лиц Распознавание образов Распознавание речи Техническое зрение Чат-боты Авторизация |
2018-05-08 22:18 техническое зрение, искусственный интеллект, искусственный интеллект в медицине Всем привет. Меня зовут Артур Кадурин, я руковожу исследованиями в области глубокого обучения для разработки новых лекарственных препаратов в компании Insilico Medicine. В Insilico мы используем самые современные методы машинного обучения, а также сами разрабатываем и публикуем множество статей для того чтобы вылечить такие заболевания как рак или болезнь Альцгеймера, а возможно и старение как таковое. В рамках подготовки своего курса по глубокому обучению я собираюсь опубликовать серию статей на тему Состязательных(Adversarial) сетей с разбором того что же это такое и как этим пользоваться. Эта серия статей не будет очередным обзором GANов(Generative Adversarial Networks), но позволит глубже заглянуть под капот нейронных сетей и охватит более широкий спектр архитектур. Хотя GANы мы конечно тоже разберем. Для того чтобы дальше беспрепятственно обсуждать состязательные сети я решил сначала сделать небольшое введение в pyTorch. Хочу сразу заметить, что это не введение в нейронные сети, поэтому я исхожу из того, что вы уже знаете такие слова как "слой", "батч", "бэкпроп" и т.д. Помимо базовых знаний о нейросетях, вам, конечно, понадобится понимание языка python. Для того чтобы было удобно пользоваться pyTorch я подготовил докер-контейнер с jupyter'ом и кодом в ноутбуках. Если вы захотите запускать обучение на видеокарте, то для видеокарт от NVIDIA вам потребуется nvidia-docker, думаю с этой частью у большинства из вас проблем не будет, поэтому остальное я оставляю вам. Все необходимое для этого поста доступно в моем репозитории spoilt333/adversarial с тегом intro на Docker Hub, или в моем репозитории на GitHub. После установки докера запустить контейнер можно например с помощью такой команды:
В контейнере автоматически запустится сервер jupyter'а, который будет доступен по http://127.0.0.1:8765 с паролем "password"(без кавычек). Если вы не хотите запускать чужой контейнер у себя на машине(правильно!), то собрать свой такой же, предварительно проверив что там все ок, можно из докерфайла который есть в репозитории на GitHub. Если у вас все запустилось и вы смогли подключиться к jupyter, то давайте перейдем к тому что же из себя представляет pyTorch. pyTorch — это большой фреймворк позволяющий создавать динамические графы вычислений и автоматически вычислять градиенты по этим графам. Для машинного обучения это как раз то что нужно. Но, помимо самой возможности обучать модели, pyTorch это еще и огромная библиотека включающая датасеты, готовые модели, современные слои и комьюнити вокруг всего этого. В Deep Learning, довольно продолжительное время, было практически стандартом тестировать все новые модели на задаче распознавания рукописных цифр. Датасет MNIST представляет из себя 70.000 размеченных рукописных цифр примерно поровну распределенных между классами. Он сразу же разбит на тренировочное и тестовое множества для того чтобы обеспечить одинаковые условия всем кто тестируется на этом датасете. В pyTorch, естественно, для него есть простые интерфейсы. Несмотря на то что сравнивать между собой state-of-the-art модели на этом датасете уже не имеет большого смысла, для демонстрационных целей он нам подойдет идеально. Каждый пример в MNIST представляет из себя изображение размером 28х28 пикселей в оттенках серого. И, как нетрудно заметить, далеко не все цифры легко может "распознать" даже человек. В ноутбуке mnist.ipynb вы можете посмотреть на пример загрузки и отображения датасета, а несколько полезных функций вынесены в файл utils.py. Но давайте перейдем к основному "блюду". В ноутбуке mnist-basic.ipynb реализована двухслойная полносвязная нейронная сеть решающая задачу классификации. Один из способов сделать нейронную сеть с помощью pyTorch — это наследоваться от класса nn.Module и реализовать свои функции инициализации и forward
Внутри функции __init__ мы объявляем слои будущей нейронной сети. В нашем случае это линейные слои nn.Linear которые имеют вид W'x+b, где W — матрица весов размером (input, output) и b — вектор смещения размером output. Эти самые веса и будут "обучаться" в процессе тренировки нейронной сети.
Метод forward используется непосредственно для преобразования входных данных с помощью заданной нейройнной сети в ее выходы. Для простоты примера мы будем работать с примерами из MNIST не как с изображениями, а как с векторами каждая размерность которых соответствует одному из пикселей. Функция view() это аналог numpy.reshape(), она переиндексирует тензор с данными заданным образом. "-1" в качестве первого аргумента функции означает, что количество элементов в первой размерности будет вычислено автоматически. Если исходный тензор x имеет размерность (N, 28, 28), то после
его размерность станет равна (N, 784).
Применение слоев к данным в pyTorch реализовано максимально просто, вы можете "вызвать" слой передав ему в качестве аргумента батч данных и получить на выходе результат преобразования. Аналогичным образом устроены и функции активации. В данном случае я использую relu, так как это наиболее популярная функция активации в задачах компьютерного зрения, однако вы легко можете поэкспериментировать с другими реализованными в pyTorch функциями, благо их там достаточно.
Так как мы решаем задачу классификации на 10 классов, то и выход нашей сети имеет размерность 10. В качестве функции активации на выходе сети мы используем softmax. Теперь значения которые возвращает функция forward можно интерпретировать как вероятности того что входной пример принадлежит к соответствующим классам.
Теперь мы можем создать экземпляр нашей сети и выбрать функцию оптимизации. Для того чтобы получился симпатичный график обучения я выбрал обыкновенный стохастический градиентный спуск, но в pyTorch, конечно же, реализованы и более продвинутые методы. Вы можете попробовать например RMSProp или Adam.
Функция train содержит основной цикл обучения в котором мы итерируемся по батчам из тренировочного множества. data — это примеры, а target — соответствующие метки. Для того чтобы pyTorch смог корректно считать градиенты мы оборачиваем данные в класс Variable, и в начале каждой итерации обнуляем текущее значение градиентов:
Обработка данных всей сетью в pyTorch ничем не отличается от применения отдельного слоя. За вызовом model(data) скрыт вызов функции forward, поэтому в output попадают выходы сети. Теперь остается только посчитать значение функции ошибки и сделать шаг обратного распространения:
На самом деле, при вызове loss.backward() веса сети еще не обновляются, но для всех весов использовавшихся при вычислении ошибки pyTorch считает градиенты используя построенный граф вычислений. Для того чтобы обновить веса мы вызываем optimizer.step(), который опираясь на свои параметры(у нас это learning rate) обновляет веса. После 20 эпох обучения наша сеть угадывает цифры с точностью 91%, что, конечно, далеко от SOTA результатов, однако, весьма неплохо для 5 минут программирования. Вот пример из тестового множества с предсказанными ответами
В следующих постах я постараюсь рассказать о состязательных сетях в таком же стиле с примерами кода и подготовленными докер-контейнерами, в частности я планирую коснуться таких тем как domain adaptation, style transfer, generative adversarial networks и разобрать несколько наиболее важных статей в этой области. Источник: habr.com Комментарии: |
|