TensorFlow CNN 网络 VGG 可视化

时间:2022-01-13 23:53:21


1. 安装 tf_cnnvis:https://github.com/InFoCusp/tf_cnnvis

2. 使用 TensorBoard 查看结果(如果无法运行 ipynb,可以直接把其中的代码复制到一个 .py 文件中运行)

3. VGG19 的可视化

TensorFlow CNN 网络 VGG 可视化



# -*- coding: utf-8 -*-
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))
import time
import copy
import numpy as np
import tensorflow as tf
from subprocess import call
from scipy.misc import imread, imresize
from tf_cnnvis import get_visualization
import vgg
import scipy.misc

def get_img(src, img_size=False):
    """Load an image file as a 3-channel array, optionally resizing it.

    Args:
        src: Path of the image file to read.
        img_size: Size forwarded to ``scipy.misc.imresize`` (this script
            passes ``(224, 224, 3)``). Any falsy value -- the default
            ``False``, ``None``, an empty tuple -- skips resizing.

    Returns:
        The image as a NumPy array whose last axis has exactly three
        channels.

    NOTE: ``scipy.misc.imread`` / ``scipy.misc.imresize`` were removed from
    newer SciPy releases (and rely on Pillow), so this helper requires an
    older SciPy installation.
    """
    img = scipy.misc.imread(src, mode='RGB')
    # mode='RGB' should already yield (H, W, 3). Normalise defensively:
    # grayscale input is replicated to three channels, and extra channels
    # (e.g. alpha) are dropped, rather than blindly tripling the whole array
    # with np.dstack (which would turn RGBA into 12 channels).
    if img.ndim == 2:
        img = np.dstack((img, img, img))
    elif img.shape[2] < 3:
        # Gray (+ possibly alpha): replicate the intensity channel only.
        img = np.dstack([img[:, :, 0]] * 3)
    elif img.shape[2] > 3:
        img = img[:, :, :3]
    # Truthiness instead of `!= False`, so None / () also mean "no resize"
    # instead of being passed to imresize.
    if img_size:
        img = scipy.misc.imresize(img, img_size)
    return img

# TensorFlow (1.x graph mode) model: VGG-19, built by the local `vgg` module
# from the pretrained weights in imagenet-vgg-verydeep-19.mat.
VGG_WEIGHTS_PATH = '../../softmax/vgg_classfication/data/imagenet-vgg-verydeep-19.mat'
SAMPLE_IMAGE_PATH = os.path.join("./sample_images", "images.jpg")
IMAGE_SHAPE = (224, 224, 3)  # (height, width, channels) fed to the network

X = tf.placeholder(tf.float32, shape=[None] + list(IMAGE_SHAPE))  # placeholder for input images

# vgg.preprocess is applied to a single, un-batched image. Reshaping the
# [None, 224, 224, 3] placeholder to [224, 224, 3] only succeeds when exactly
# one image is fed, so the batch size is effectively fixed to 1.
image = tf.reshape(X, shape=list(IMAGE_SHAPE))
image_pre = vgg.preprocess(image)  # presumably mean-pixel subtraction -- confirm in vgg.py
image_pre = tf.reshape(image_pre, shape=[-1] + list(IMAGE_SHAPE))  # restore the batch axis
image_pre = tf.to_float(image_pre)

# Build the pretrained VGG-19 layers. `net` itself is not used below; building
# it populates the default graph that is handed to get_visualization.
net = vgg.net(VGG_WEIGHTS_PATH, image_pre)

im = get_img(SAMPLE_IMAGE_PATH, IMAGE_SHAPE).astype(np.float32)
im = np.expand_dims(im, axis=0)  # add the batch axis expected by X

# Layer-type selectors for tf_cnnvis; per its README, "r", "p" and "c" select
# ReLU, pooling and convolution layers -- confirm against the installed version.
layers = ["r", "p", "c"]

start_time = time.time()
get_visualization(graph_or_path=tf.get_default_graph(), value_feed_dict={X: im}, layers=layers, n=8)
elapsed = time.time() - start_time
print("Total Time = %f" % (elapsed))