Аниме и генеративно-состязательная сеть: в чём связь?

МЕНЮ


Искусственный интеллект
Поиск
Регистрация на сайте
Помощь проекту

ТЕМЫ


Новости ИИРазработка ИИВнедрение ИИРабота разума и сознаниеМодель мозгаРобототехника, БПЛАТрансгуманизмОбработка текстаТеория эволюцииДополненная реальностьЖелезоКиберугрозыНаучный мирИТ индустрияРазработка ПОТеория информацииМатематикаЦифровая экономика

Авторизация



RSS


RSS новости


Генеративно-состязательная сеть, которую вы построите, создаёт персонажей из манги и аниме. Рисуйте вайфу в своё удовольствие!

Аниме и генеративно-состязательная сеть: в чём связь?

Давно хотели создать своих Аску, Код 002 или Канеки Кена? У вас появилась отличная возможность это сделать :)

Что такое генеративно-состязательная сеть?

Лучший вывод, который может генерировать нейронная сеть, похож на человеческий. Образно генеративно-состязательная сеть (GAN) может даже обмануть человека, заставив его думать, что вывод сделан им самим.

В генеративно-состязательных сетях две сети соревнуются друг с другом, что приводит к взаимным импровизациям. Генератор обманывает дискриминатор, создавая ложные входы и выдавая их за реальные. Дискриминатор сообщает, является ли ввод реальным или ложным.

Аниме и генеративно-состязательная сеть: в чём связь?

В обучении GAN есть три главных шага:

    1. Используйте генератор для создания ложных входов из случайного шума.
    2. Обучите дискриминатор на ложных и реальных входах (одновременно с объединением или поочерёдно, что предпочтительнее).
    3. Обучите всю модель: дискриминатор + генератор.

Помните, что весовые коэффициенты дискриминатора «заморожены» во время последнего шага.

Причина сочетания обоих сетей состоит в отсутствии обратной связи на выходах генератора. Единственный ориентир – если дискриминатор принимает выходы генератора.

Генеративно-состязательная сеть

Можно сказать, что они соперничают друг с другом. Генератор обучается во время схватки с «соперником», чтобы реализовать цель.

Наша сеть

Для задачи мы используем глубокую сверточную генеративно-состязательную сеть.

Вот несколько рекомендаций для таких сетей:

  1. Замените максимальные подвыборки шагами свёртки.
  2. Используйте перемещённую свёртку для повышения частоты дискретизации.
  3. Устраните полностью соединённые слои.
  4. Используйте пакетную нормализацию, кроме выходного слоя генератора и входного слоя дискриминатора.
  5. Используйте ReLU в генераторе, кроме выхода, который использует tanh.
  6. Leaky ReLU в дискриминаторе.

Детали сетапа

  • версия Keras==2.2.4
  • TensorFlow==1.8.0
  • Jupyter Notebook
  • Matplotlib и другие библиотеки типа NumPy, Pandas
  • Python==3.5.7

Набор данных

Набор данных для лиц из аниме можно собрать на тематических сайтах, скачивая картинки и вырезая лица. Код Python, который демонстрирует это.

Также доступны обработанные и обрезанные лица.

Аниме и генеративно-состязательная сеть: в чём связь?

Генератор

Он состоит из сверточных слоёв, пакетной нормализации и функции активации Leaky ReLU для повышения частоты дискретизации. Мы используем параметр шагов в сверточном слое, чтобы избежать нестабильной обучаемости GAN. Функция не будет равно нулю, если x < 0, вместо этого Leaky ReLU имеет небольшое отрицательное отклонение (0.01 и так далее).

Аниме и генеративно-состязательная сеть: в чём связь?

Код

def get_gen_normal(noise_shape):      kernel_init = 'glorot_uniform'          gen_input = Input(shape = noise_shape)             generator = Conv2DTranspose(filters = 512, kernel_size = (4,4), strides = (1,1), padding = "valid", data_format = "channels_last", kernel_initializer = kernel_init)(gen_input)      generator = BatchNormalization(momentum = 0.5)(generator)      generator = LeakyReLU(0.2)(generator)            generator = Conv2DTranspose(filters = 256, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator)      generator = BatchNormalization(momentum = 0.5)(generator)      generator = LeakyReLU(0.2)(generator)            generator = Conv2DTranspose(filters = 128, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator)      generator = BatchNormalization(momentum = 0.5)(generator)      generator = LeakyReLU(0.2)(generator)            generator = Conv2DTranspose(filters = 64, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator)      generator = BatchNormalization(momentum = 0.5)(generator)      generator = LeakyReLU(0.2)(generator)            generator = Conv2D(filters = 64, kernel_size = (3,3), strides = (1,1), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator)      generator = BatchNormalization(momentum = 0.5)(generator)      generator = LeakyReLU(0.2)(generator)            generator = Conv2DTranspose(filters = 3, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator)            generator = Activation('tanh')(generator)            gen_opt = Adam(lr=0.00015, beta_1=0.5)      generator_model = Model(input = gen_input, output = generator)      generator_model.compile(loss='binary_crossentropy', optimizer=gen_opt, metrics=['accuracy'])      generator_model.summary()        return generator_model

Дискриминатор

Он также состоит из сверточных слоёв, где мы используем шаги для понижения частоты дискретизации и пакетной нормализации для стабильности.

Аниме и генеративно-состязательная сеть: в чём связь?

Код

def get_disc_normal(image_shape=(64,64,3)):      dropout_prob = 0.4      kernel_init = 'glorot_uniform'      dis_input = Input(shape = image_shape)            discriminator = Conv2D(filters = 64, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(dis_input)      discriminator = LeakyReLU(0.2)(discriminator)        discriminator = Conv2D(filters = 128, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(discriminator)      discriminator = BatchNormalization(momentum = 0.5)(discriminator)      discriminator = LeakyReLU(0.2)(discriminator)        discriminator = Conv2D(filters = 256, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(discriminator)      discriminator = BatchNormalization(momentum = 0.5)(discriminator)      discriminator = LeakyReLU(0.2)(discriminator)        discriminator = Conv2D(filters = 512, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(discriminator)      discriminator = BatchNormalization(momentum = 0.5)(discriminator)      discriminator = LeakyReLU(0.2)(discriminator)        discriminator = Flatten()(discriminator)        discriminator = Dense(1)(discriminator)        discriminator = Activation('sigmoid')(discriminator)        dis_opt = Adam(lr=0.0002, beta_1=0.5)      discriminator_model = Model(input = dis_input, output = discriminator)      discriminator_model.compile(loss='binary_crossentropy', optimizer=dis_opt, metrics=['accuracy'])      discriminator_model.summary()        return discriminator_model

Полная GAN

Чтобы генератор проверял выход, скомпилируем сеть в Keras. В этой сети входом будет случайный шум для генератора, а выходом – выход генератора «скормленный» дискриминатору. Во избежание «коллапса состязания» весовые коэффициенты дискриминатора заморожены.

Код

discriminator.trainable = False    opt = Adam(lr=0.00015, beta_1=0.5) #same as generator    gen_inp = Input(shape=noise_shape)  GAN_inp = generator(gen_inp)  GAN_opt = discriminator(GAN_inp)    gan = Model(input = gen_inp, output = GAN_opt)  gan.compile(loss = 'binary_crossentropy', optimizer = opt, metrics=['accuracy'])  gan.summary()
Аниме и генеративно-состязательная сеть: в чём связь?

Тренировка модели

Генеративно-состязательная сеть

Базовая конфигурация модели

Аниме и генеративно-состязательная сеть: в чём связь?

1. Сгенерируйте случайный нормальный шум для входа:

def gen_noise(batch_size, noise_shape):      return np.random.normal(0, 1, size=(batch_size,)+noise_shape)

2. Объедините реальные данные из набора с шумом:

data_X = np.concatenate([real_data_X, fake_data_X])  real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2  fake_data_Y = np.random.random_sample(batch_size)*0.2    data_Y = np.concatenate([real_data_Y, fake_data_Y])

3. Подайте шум на вход:

real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2  fake_data_Y = np.random.random_sample(batch_size)*0.2

4. Тренируем только генератор

    print("Begin step: ", tot_step)      step_begin_time = time.time()             real_data_X = sample_from_dataset(batch_size, image_shape, data_dir=data_dir)            noise = gen_noise(batch_size,noise_shape)            fake_data_X = generator.predict(noise)            if (tot_step % 100) == 0:          step_num = str(tot_step).zfill(4)          save_img_batch(fake_data_X,img_save_dir+step_num+"_image.png")

5. Тренируем только дискриминатор:

    discriminator.trainable = True      generator.trainable = False        dis_metrics_real = discriminator.train_on_batch(real_data_X,real_data_Y)       dis_metrics_fake = discriminator.train_on_batch(fake_data_X,fake_data_Y)             print("Disc: real loss: %f fake loss: %f" % (dis_metrics_real[0], dis_metrics_fake[0]))

6. Тренируем совмещённую GAN:

    generator.trainable = True      discriminator.trainable = False        GAN_X = gen_noise(batch_size,noise_shape)      GAN_Y = real_data_Y           gan_metrics = gan.train_on_batch(GAN_X,GAN_Y)      print("GAN loss: %f" % (gan_metrics[0]))            text_file = open(log_dir+"	raining_log.txt", "a")      text_file.write("Step: %d Disc: real loss: %f fake loss: %f GAN loss: %f " % (tot_step, dis_metrics_real[0], dis_metrics_fake[0],gan_metrics[0]))      text_file.close()        avg_GAN_loss.append(gan_metrics[0])                    end_time = time.time()      diff_time = int(end_time - step_begin_time)      print("Step %d completed. Time took: %s secs." % (tot_step, diff_time))

7. Сохраните экземпляры дискриминатора и генератора:

if ((tot_step+1) % 500) == 0:          print("-----------------------------------------------------------------")          print("Average Disc_fake loss: %f" % (np.mean(avg_disc_fake_loss)))           print("Average Disc_real loss: %f" % (np.mean(avg_disc_real_loss)))           print("Average GAN loss: %f" % (np.mean(avg_GAN_loss)))          print("-----------------------------------------------------------------")          discriminator.trainable = False          generator.trainable = False          # predict on fixed_noise          fixed_noise_generate = generator.predict(noise)          step_num = str(tot_step).zfill(4)          save_img_batch(fixed_noise_generate,img_save_dir+step_num+"fixed_image.png")          generator.save(save_model_dir+str(tot_step)+"_GENERATOR_weights_and_arch.hdf5")          discriminator.save(save_model_dir+str(tot_step)+"_DISCRIMINATOR_weights_and_arch.hdf5")

Результаты Манга-генератора

После 10000 шагов обучения результат выглядит круто! Смотрите сами.

Генеративно-состязательная сеть

Аниме и генеративно-состязательная сеть: в чём связь?

Более длительная тренировка с большим набором данных приведёт к лучшим результатам. (Некоторые лица получились страшными, это правда :D)

Заключение

Наверняка задача генерации лиц в стиле аниме интересна.

Но в ней ещё есть место для улучшений: лучшее обучение, модели и набор данных. Наша модель вряд ли заставит человека задуматься, реальны ли сгенерированные модели. Тем не менее, она отлично справляется с поставленной задачей. Вы можете улучшить сеть, используя персонажей во весь рост.

Исходный код проекта доступен на GitHub.

А что бы вы улучшили в этой нейронной сети?


Источник: proglib.io

Комментарии: