keras 绘制网络结构

先从 https://www.graphviz.org/download/ 下载 Graphviz(可以下载 zip 免安装版),并执行 `pip install pydot` 安装 pydot(Keras 的 `plot_model` 通过 pydot 调用 Graphviz,缺少它会报错)

解压之后,将 Graphviz 目录下的 bin 目录添加到系统环境变量 PATH 中(修改后可能需要重启终端或 IDE 才能生效)

再执行以下代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
"""
导包
"""
import matplotlib.pyplot as plt
from keras.layers import Dense, Dropout, Convolution2D, MaxPool2D, Flatten
from keras.models import Sequential
from keras.utils.vis_utils import plot_model

"""
构建网络
"""

model = Sequential([
Convolution2D(input_shape=(28, 28, 1), filters=32, kernel_size=1,
strides=1, padding='same', activation='relu'),
MaxPool2D(pool_size=2, strides=2, padding='same'),
Convolution2D(filters=64, kernel_size=5, strides=1, padding='same', activation='relu'),
MaxPool2D(pool_size=2, strides=2, padding='same'),
Flatten(),
Dense(units=1024, activation='relu'),
Dropout(0.5),
Dense(units=10, activation='softmax'),
])

"""
绘制网络
"""

plot_model(model, to_file="model.png", show_shapes=True, show_layer_names="False", rankdir="TB")
img = plt.imread("model.png")
plt.imshow(img)
plt.show()