Интерактивное использование обученной модели tensorflow
Я играл с моделью тензорного потока LSTM (суммирование предложений) и довел ее до того, что она делает довольно хорошую работу. Но когда я пытаюсь импортировать сохраненную модель и использовать ее в интерактивном режиме, она не дает никаких результатов.
Мой исходный код использовал контрольные точки, но я переключился на SavedModelBuilder (), потому что думал, что с ним будет легче работать.
import tensorflow as tf from Seq2SeqModel import Seq2SeqModel import utils from time import gmtime, strftime import sys, os ############################################################################### # read in our input and output files ############################################################################### # for this example we will expect our input and output files to have the same number of lines. # input[x] should be translated into output[x] ############################################################################### print("="*60) DATADIR = "../DATASETS/sentence-compression-master/data/" INPUT_DATA = DATADIR + "long_sentences.txt" OUTPUT_DATA = DATADIR + "short_sentences.txt" # read the files in as 'utf-8-sig' to ignore a leading Byte-Order-Encoding bit if present with open(INPUT_DATA, 'r', encoding='utf-8') as f: in_sentences = f.readlines() with open(OUTPUT_DATA, 'r', encoding='utf-8') as f: out_sentences = f.readlines() in_chars = [] TEMP_LINE = " abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ[]()-_,.:;!?0123456789$+'\"" for char in TEMP_LINE: in_chars.append(char) ############################################################################### # create our encoding/decoding scheme # note since we are not doing any special encoding (make lower case) or gathering the # corpus (bucket of words, or characters found in text) we could skip this part. ############################################################################### # our encoder and decoder will be stored as dictionaries converting between Chars and Ints: char_to_int = dict() int_to_char = dict() for i,char in enumerate(in_chars): char_to_int[char] = i int_to_char[i] = char num_in_chars = len(in_chars) max_in_chars_per_sample = max([len(sample) for sample in in_sentences]) max_out_chars_per_sample = max([len(sample) for sample in out_sentences]) num_samples = len(in_sentences) print(strftime("%Y-%m-%d %H:%M:%S", gmtime()) + " Create our Training Data") X_train = [] y_train = [] for sent in in_sentences: id_sent = [] for mychar in sent: if mychar in in_chars: vocab_id = char_to_int[mychar] id_sent += [vocab_id] X_train += [id_sent] for sent in out_sentences: id_sent = [] for mychar in sent: if mychar in in_chars: vocab_id = char_to_int[mychar] id_sent += [vocab_id] y_train += [id_sent] #tf.set_random_seed(1) # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8) ############################################################################### # train the model, X is our input, Y is our output ############################################################################### print(strftime("%Y-%m-%d %H:%M:%S", gmtime()) + " Train our Model") step = 0 batch_size = 32 max_batches = int(len(X_train) / batch_size) batches_in_epoch = 1 epoch_be_saved = 1 my_beam_width = 3 # if checkpoint found CHKPT_FOUND = False for line in open("../models/checkpoint"): if len(line) > 1: data = line.split('"') checkpoint = "../models/" + data[1] data2 = data[1].split("-") step = int(data2[1]) CHKPT_FOUND = True print("Selecting previous checkpoint ../models/nmt.ckpt-" + str(step) + ".index") g = tf.Graph() with g.as_default(): model = Seq2SeqModel(encoder_num_units = 512, decoder_num_units = 512, embedding_size = 512, num_layers = 2, vocab_size = num_in_chars, batch_size = batch_size, bidirectional = False, attention = True, beam_search = True, beam_width = my_beam_width, mode = "Train") print(strftime("%Y-%m-%d %H:%M:%S", gmtime()) + " model constructed.") builder = tf.saved_model.builder.SavedModelBuilder('./SavedModel/') with tf.Session(config=tf.ConfigProto()) as sess: sess.run(tf.global_variables_initializer()) saver = tf.train.Saver() summary_writer = tf.summary.FileWriter('../log', graph = sess.graph) if CHKPT_FOUND == True: print('loading previous checkpoint [' + checkpoint + ']') saver.restore(sess, checkpoint) builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING], signature_def_map=None, assets_collection=None) builder.save() print('start training.') for _epoch in range(1, batches_in_epoch + 1): for _batch in range(max_batches + 1): if step % 5 == 0: print(strftime("%Y-%m-%d %H:%M:%S", gmtime()) + " Processing Batch " + str(_batch) + " of " + str(max_batches)) X, y = utils.input_generator(X_train, y_train, batch_size) feed_dict = model.make_train_inputs(x = X, y = y) _, l, train_sentences, summary_str = sess.run([model.train_op, model.loss, model.decoder_predictions_train, model.summary_op], feed_dict) summary_writer.add_summary(summary_str, _epoch * _batch) if step == 0 or step % 25 == 0: print(strftime("%Y-%m-%d %H:%M:%S", gmtime()) + " Step {}".format(step)) print(' minibatch loss: {}'.format(sess.run(model.loss, feed_dict))) for i in range(1): print('train logits:') train_sentence = '' for mychar in train_sentences[i]: train_sentence += str(int_to_char[mychar]) print(train_sentence) # we could also print out X and y sentences for better context print(' ') if step % 100 == 0: saver.save(sess, '../models/' + 'nmt.ckpt', global_step = step) print(strftime("%Y-%m-%d %H:%M:%S", gmtime()) + " model saved at step = " + str(step)) step += 1 print(strftime("%Y-%m-%d %H:%M:%S", gmtime()) + " epoch finished") if _epoch % epoch_be_saved == 0: tf.saved_model.simple_save(sess, 'models/0', inputs, outputs) saver.save(sess, '../models/' + 'nmt.ckpt', global_step = step) print(strftime("%Y-%m-%d %H:%M:%S", gmtime()) + " model saved at step = " + str(step)) print(strftime("%Y-%m-%d %H:%M:%S", gmtime()) + " Finished Training")
Затем я смог загрузить свою модель в другой скрипт, но когда я пытаюсь что-то сделать с ней, все, что я получаю, - это пустые выходные массивы.
import tensorflow as tf from tensorflow.python.saved_model import tag_constants from Seq2SeqModel import Seq2SeqModel import utils from time import gmtime, strftime import sys, os ############################################################################### # read in our input and output files ############################################################################### in_sentences = [] out_sentences = [] in_chars = [] TEMP_LINE = " abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ[]()-_,.:;!?0123456789$+'\"" for char in TEMP_LINE: in_chars.append(char) ############################################################################### # create our encoding/decoding scheme # note since we are not doing any special encoding (make lower case) or gathering the # corpus (bucket of words, or characters found in text) we could skip this part. ############################################################################### # our encoder and decoder will be stored as dictionaries converting between Chars and Ints: char_to_int = dict() int_to_char = dict() for i,char in enumerate(in_chars): char_to_int[char] = i int_to_char[i] = char num_in_chars = len(in_chars) ############################################################################### # train the model, X is our input, Y is our output ############################################################################### checkpoint = 'models/0' step = 0 batch_size = 32 max_batches = 1 batches_in_epoch = 1 epoch_be_saved = 1 my_beam_width = 3 temperature = 1 top_k = 0 g = tf.Graph() with g.as_default(): model = Seq2SeqModel(encoder_num_units = 512, decoder_num_units = 512, embedding_size = 512, num_layers = 2, vocab_size = num_in_chars, batch_size = batch_size, bidirectional = False, attention = True, beam_search = True, beam_width = my_beam_width, mode = "Infer" ) print(strftime("%Y-%m-%d %H:%M:%S", gmtime()) + " model constructed.") gvars = g.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) assign_ops = [g.get_operation_by_name(v.op.name + "/Assign") for v in gvars] init_values = [assign_op.inputs[1] for assign_op in assign_ops] #hparams = model.default_hparams() #length = hparams.n_ctx // 2 with tf.Session(config=tf.ConfigProto()) as sess: tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], './SavedModel/') #sess.run(tf.global_variables_initializer()) print(strftime("%Y-%m-%d %H:%M:%S", gmtime()) + " model loaded.") context_tokens = [] sent = "World renowned Count Olaf performed in a play and got a standing ovation from the crowd." for mychar in sent: if mychar in in_chars: context_tokens.append(char_to_int[mychar]) generated = 0 print(sent) output = [] feed_dict = model.make_infer_inputs(x = [context_tokens]) print(feed_dict) out = sess.run(output, feed_dict) print(output) print("xxxxxxxxxxxxxxxxxx") print(out) out_string = "" for i in range(len(output)): out_string += int_to_char[out[i]] print(out_string) print("=" * 40)
Это дает мне следующий результат:
World renowned Count Olaf performed in a play and got a standing ovation from the crowd. {<tf.Tensor 'encoder_inputs:0' shape=(?, ?) dtype=int32>: array([[49, 15, 18, 12, 4, 0, 18, 5, 14, 15, 23, 14, 5, 4, 0, 29, 15, 21, 14, 20, 0, 41, 12, 1, 6, 0, 16, 5, 18, 6, 15, 18, 13, 5, 4, 0, 9, 14, 0, 1, 0, 16, 12, 1, 25, 0, 1, 14, 4, 0, 7, 15, 20, 0, 1, 0, 19, 20, 1, 14, 4, 9, 14, 7, 0, 15, 22, 1, 20, 9, 15, 14, 0, 6, 18, 15, 13, 0, 20, 8, 5, 0, 3, 18, 15, 23, 4, 60]]), <tf.Tensor 'encoder_inputs_length:0' shape=(?,) dtype=int32>: [88]} [] xxxxxxxxxxxxxxxxxx []
Мой feed_dict выглядит нормально, и я, вероятно, делаю какую-то очевидную ошибку, но я не могу ее найти.
Как я могу взять обученную модель, заморозить ее на месте, а затем использовать ее для обработки отдельных предложений?
Я понимаю, что многое могу сделать, чтобы очистить код, но, к сожалению, поскольку я пробовал разные подходы (и все еще ничего не получал), код стал немного неряшливым. Пожалуйста, помогите, прежде чем он превратится в нечитаемый беспорядок :-)
При необходимости я могу удалить часть кода, чтобы упростить вещи, я просто беспокоился, что кто-то может попросить полный код для оценки.
Что я уже пробовал:
Я пробовал следовать нескольким учебным пособиям в интернете, а также просматривать игрушечный код для SavedModelBuilder и работать с контрольными точками, но безрезультатно.