Аниме и генеративно-состязательная сеть: в чём связь? |
||
МЕНЮ Искусственный интеллект Поиск Регистрация на сайте Помощь проекту ТЕМЫ Новости ИИ Искусственный интеллект Разработка ИИГолосовой помощник Городские сумасшедшие ИИ в медицине ИИ проекты Искусственные нейросети Слежка за людьми Угроза ИИ ИИ теория Внедрение ИИКомпьютерные науки Машинное обуч. (Ошибки) Машинное обучение Машинный перевод Реализация ИИ Реализация нейросетей Создание беспилотных авто Трезво про ИИ Философия ИИ Big data Работа разума и сознаниеМодель мозгаРобототехника, БПЛАТрансгуманизмОбработка текстаТеория эволюцииДополненная реальностьЖелезоКиберугрозыНаучный мирИТ индустрияРазработка ПОТеория информацииМатематикаЦифровая экономика
Генетические алгоритмы Капсульные нейросети Основы нейронных сетей Распознавание лиц Распознавание образов Распознавание речи Техническое зрение Чат-боты Авторизация |
2019-09-23 00:29 Генеративно-состязательная сеть, которую вы построите, создаёт персонажей из манги и аниме. Рисуйте вайфу в своё удовольствие! Давно хотели создать своих Аску, Код 002 или Канеки Кена? У вас появилась отличная возможность это сделать :) Что такое генеративно-состязательная сеть? Лучший вывод, который может генерировать нейронная сеть, похож на человеческий. Образно генеративно-состязательная сеть (GAN) может даже обмануть человека, заставив его думать, что вывод сделан им самим. В генеративно-состязательных сетях две сети соревнуются друг с другом, что приводит к взаимным импровизациям. Генератор обманывает дискриминатор, создавая ложные входы и выдавая их за реальные. Дискриминатор сообщает, является ли ввод реальным или ложным. В обучении GAN есть три главных шага:
Помните, что весовые коэффициенты дискриминатора «заморожены» во время последнего шага. Причина сочетания обоих сетей состоит в отсутствии обратной связи на выходах генератора. Единственный ориентир – если дискриминатор принимает выходы генератора. Можно сказать, что они соперничают друг с другом. Генератор обучается во время схватки с «соперником», чтобы реализовать цель. Наша сеть Для задачи мы используем глубокую сверточную генеративно-состязательную сеть. Вот несколько рекомендаций для таких сетей:
Детали сетапа
Набор данных Набор данных для лиц из аниме можно собрать на тематических сайтах, скачивая картинки и вырезая лица. Код Python, который демонстрирует это. Также доступны обработанные и обрезанные лица. Генератор Он состоит из сверточных слоёв, пакетной нормализации и функции активации Leaky ReLU для повышения частоты дискретизации. Мы используем параметр шагов в сверточном слое, чтобы избежать нестабильной обучаемости GAN. Функция не будет равно нулю, если x < 0, вместо этого Leaky ReLU имеет небольшое отрицательное отклонение (0.01 и так далее). Код def get_gen_normal(noise_shape): kernel_init = 'glorot_uniform' gen_input = Input(shape = noise_shape) generator = Conv2DTranspose(filters = 512, kernel_size = (4,4), strides = (1,1), padding = "valid", data_format = "channels_last", kernel_initializer = kernel_init)(gen_input) generator = BatchNormalization(momentum = 0.5)(generator) generator = LeakyReLU(0.2)(generator) generator = Conv2DTranspose(filters = 256, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator) generator = BatchNormalization(momentum = 0.5)(generator) generator = LeakyReLU(0.2)(generator) generator = Conv2DTranspose(filters = 128, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator) generator = BatchNormalization(momentum = 0.5)(generator) generator = LeakyReLU(0.2)(generator) generator = Conv2DTranspose(filters = 64, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator) generator = BatchNormalization(momentum = 0.5)(generator) generator = LeakyReLU(0.2)(generator) generator = Conv2D(filters = 64, kernel_size = (3,3), strides = (1,1), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator) generator = BatchNormalization(momentum = 0.5)(generator) generator = LeakyReLU(0.2)(generator) generator = Conv2DTranspose(filters = 3, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator) generator = Activation('tanh')(generator) gen_opt = Adam(lr=0.00015, beta_1=0.5) generator_model = Model(input = gen_input, output = generator) generator_model.compile(loss='binary_crossentropy', optimizer=gen_opt, metrics=['accuracy']) generator_model.summary() return generator_model Дискриминатор Он также состоит из сверточных слоёв, где мы используем шаги для понижения частоты дискретизации и пакетной нормализации для стабильности. Код def get_disc_normal(image_shape=(64,64,3)): dropout_prob = 0.4 kernel_init = 'glorot_uniform' dis_input = Input(shape = image_shape) discriminator = Conv2D(filters = 64, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(dis_input) discriminator = LeakyReLU(0.2)(discriminator) discriminator = Conv2D(filters = 128, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(discriminator) discriminator = BatchNormalization(momentum = 0.5)(discriminator) discriminator = LeakyReLU(0.2)(discriminator) discriminator = Conv2D(filters = 256, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(discriminator) discriminator = BatchNormalization(momentum = 0.5)(discriminator) discriminator = LeakyReLU(0.2)(discriminator) discriminator = Conv2D(filters = 512, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(discriminator) discriminator = BatchNormalization(momentum = 0.5)(discriminator) discriminator = LeakyReLU(0.2)(discriminator) discriminator = Flatten()(discriminator) discriminator = Dense(1)(discriminator) discriminator = Activation('sigmoid')(discriminator) dis_opt = Adam(lr=0.0002, beta_1=0.5) discriminator_model = Model(input = dis_input, output = discriminator) discriminator_model.compile(loss='binary_crossentropy', optimizer=dis_opt, metrics=['accuracy']) discriminator_model.summary() return discriminator_model Полная GAN Чтобы генератор проверял выход, скомпилируем сеть в Keras. В этой сети входом будет случайный шум для генератора, а выходом – выход генератора «скормленный» дискриминатору. Во избежание «коллапса состязания» весовые коэффициенты дискриминатора заморожены. Код discriminator.trainable = False opt = Adam(lr=0.00015, beta_1=0.5) #same as generator gen_inp = Input(shape=noise_shape) GAN_inp = generator(gen_inp) GAN_opt = discriminator(GAN_inp) gan = Model(input = gen_inp, output = GAN_opt) gan.compile(loss = 'binary_crossentropy', optimizer = opt, metrics=['accuracy']) gan.summary() Тренировка модели Базовая конфигурация модели 1. Сгенерируйте случайный нормальный шум для входа: def gen_noise(batch_size, noise_shape): return np.random.normal(0, 1, size=(batch_size,)+noise_shape) 2. Объедините реальные данные из набора с шумом: data_X = np.concatenate([real_data_X, fake_data_X]) real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2 fake_data_Y = np.random.random_sample(batch_size)*0.2 data_Y = np.concatenate([real_data_Y, fake_data_Y]) 3. Подайте шум на вход: real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2 fake_data_Y = np.random.random_sample(batch_size)*0.2 4. Тренируем только генератор print("Begin step: ", tot_step) step_begin_time = time.time() real_data_X = sample_from_dataset(batch_size, image_shape, data_dir=data_dir) noise = gen_noise(batch_size,noise_shape) fake_data_X = generator.predict(noise) if (tot_step % 100) == 0: step_num = str(tot_step).zfill(4) save_img_batch(fake_data_X,img_save_dir+step_num+"_image.png") 5. Тренируем только дискриминатор: discriminator.trainable = True generator.trainable = False dis_metrics_real = discriminator.train_on_batch(real_data_X,real_data_Y) dis_metrics_fake = discriminator.train_on_batch(fake_data_X,fake_data_Y) print("Disc: real loss: %f fake loss: %f" % (dis_metrics_real[0], dis_metrics_fake[0])) 6. Тренируем совмещённую GAN: generator.trainable = True discriminator.trainable = False GAN_X = gen_noise(batch_size,noise_shape) GAN_Y = real_data_Y gan_metrics = gan.train_on_batch(GAN_X,GAN_Y) print("GAN loss: %f" % (gan_metrics[0])) text_file = open(log_dir+" raining_log.txt", "a") text_file.write("Step: %d Disc: real loss: %f fake loss: %f GAN loss: %f " % (tot_step, dis_metrics_real[0], dis_metrics_fake[0],gan_metrics[0])) text_file.close() avg_GAN_loss.append(gan_metrics[0]) end_time = time.time() diff_time = int(end_time - step_begin_time) print("Step %d completed. Time took: %s secs." % (tot_step, diff_time)) 7. Сохраните экземпляры дискриминатора и генератора: if ((tot_step+1) % 500) == 0: print("-----------------------------------------------------------------") print("Average Disc_fake loss: %f" % (np.mean(avg_disc_fake_loss))) print("Average Disc_real loss: %f" % (np.mean(avg_disc_real_loss))) print("Average GAN loss: %f" % (np.mean(avg_GAN_loss))) print("-----------------------------------------------------------------") discriminator.trainable = False generator.trainable = False # predict on fixed_noise fixed_noise_generate = generator.predict(noise) step_num = str(tot_step).zfill(4) save_img_batch(fixed_noise_generate,img_save_dir+step_num+"fixed_image.png") generator.save(save_model_dir+str(tot_step)+"_GENERATOR_weights_and_arch.hdf5") discriminator.save(save_model_dir+str(tot_step)+"_DISCRIMINATOR_weights_and_arch.hdf5") Результаты Манга-генератора После 10000 шагов обучения результат выглядит круто! Смотрите сами.
Более длительная тренировка с большим набором данных приведёт к лучшим результатам. (Некоторые лица получились страшными, это правда :D) Заключение Наверняка задача генерации лиц в стиле аниме интересна. Но в ней ещё есть место для улучшений: лучшее обучение, модели и набор данных. Наша модель вряд ли заставит человека задуматься, реальны ли сгенерированные модели. Тем не менее, она отлично справляется с поставленной задачей. Вы можете улучшить сеть, используя персонажей во весь рост. Исходный код проекта доступен на GitHub. А что бы вы улучшили в этой нейронной сети? Источник: proglib.io Комментарии: |
|