DL之DNN:DNN优化技术之利用MultiLayerNetExtend算法(BN层使用/不使用+权重初始值不同)对Mnist数据集训练评估学习过程
DL之DNN:DNN优化技术之利用MultiLayerNetExtend算法(BN层使用/不使用+权重初始值不同)对Mnist数据集训练评估学习过程
输出结果
更多输出详见最后
设计思路
核心代码
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)
x_train = x_train[:1000]
t_train = t_train[:1000]
max_epochs = 20
train_size = x_train.shape[0]
batch_size = 100
learning_rate = 0.01
bn_network = MultiLayerNetExtend(input_size=784, hidden_size_list=[100, 100, 100, 100, 100], output_size=10,
weight_init_std=weight_init_std, use_batchnorm=True)
network = MultiLayerNetExtend(input_size=784, hidden_size_list=[100, 100, 100, 100, 100], output_size=10,
weight_init_std=weight_init_std)
optimizer = SGD(lr=learning_rate)
train_acc_list = []
bn_train_acc_list = []
iter_per_epoch = max(train_size / batch_size, 1)
for i in range(1000000000):
#定义x_batch、t_batch
batch_mask = np.random.choice(train_size, batch_size)
x_batch = x_train[batch_mask]
t_batch = t_train[batch_mask]
for _network in (bn_network, network):
grads = _network.gradient(x_batch, t_batch)
optimizer.update(_network.params, grads)
if i % iter_per_epoch == 0:
train_acc = network.accuracy(x_train, t_train)
bn_train_acc = bn_network.accuracy(x_train, t_train)
train_acc_list.append(train_acc)
bn_train_acc_list.append(bn_train_acc)
print("epoch:" + str(epoch_cnt) + " | " + str(train_acc) + " - " + str(bn_train_acc))
epoch_cnt += 1
if epoch_cnt >= max_epochs:
break
return train_acc_list, bn_train_acc_list
更多输出
============== 1/16 ==============
epoch:0 | 0.093 - 0.085
epoch:1 | 0.117 - 0.08
epoch:2 | 0.117 - 0.081
epoch:3 | 0.117 - 0.1
epoch:4 | 0.117 - 0.125
epoch:5 | 0.117 - 0.143
epoch:6 | 0.117 - 0.163
epoch:7 | 0.117 - 0.191
epoch:8 | 0.117 - 0.213
epoch:9 | 0.117 - 0.236
epoch:10 | 0.117 - 0.258
epoch:11 | 0.117 - 0.268
epoch:12 | 0.117 - 0.28
epoch:13 | 0.117 - 0.297
epoch:14 | 0.117 - 0.31
epoch:15 | 0.117 - 0.322
epoch:16 | 0.117 - 0.335
epoch:17 | 0.117 - 0.36
epoch:18 | 0.116 - 0.378
epoch:19 | 0.117 - 0.391
============== 2/16 ==============
epoch:0 | 0.087 - 0.099
epoch:1 | 0.097 - 0.108
epoch:2 | 0.097 - 0.151
epoch:3 | 0.097 - 0.185
epoch:4 | 0.097 - 0.216
epoch:5 | 0.097 - 0.226
epoch:6 | 0.097 - 0.243
epoch:7 | 0.097 - 0.281
epoch:8 | 0.097 - 0.306
epoch:9 | 0.097 - 0.323
epoch:10 | 0.097 - 0.344
epoch:11 | 0.097 - 0.364
epoch:12 | 0.097 - 0.38
epoch:13 | 0.097 - 0.394
epoch:14 | 0.097 - 0.402
epoch:15 | 0.097 - 0.415
epoch:16 | 0.097 - 0.441
epoch:17 | 0.097 - 0.454
epoch:18 | 0.097 - 0.464
epoch:19 | 0.097 - 0.48
============== 3/16 ==============
epoch:0 | 0.104 - 0.108
epoch:1 | 0.364 - 0.111
epoch:2 | 0.499 - 0.121
epoch:3 | 0.587 - 0.153
epoch:4 | 0.679 - 0.166
epoch:5 | 0.735 - 0.197
epoch:6 | 0.777 - 0.229
epoch:7 | 0.817 - 0.259
epoch:8 | 0.855 - 0.297
epoch:9 | 0.882 - 0.333
epoch:10 | 0.906 - 0.364
epoch:11 | 0.926 - 0.396
epoch:12 | 0.927 - 0.421
epoch:13 | 0.943 - 0.436
epoch:14 | 0.954 - 0.457
epoch:15 | 0.957 - 0.491
epoch:16 | 0.962 - 0.51
epoch:17 | 0.971 - 0.531
epoch:18 | 0.978 - 0.551
epoch:19 | 0.983 - 0.561
============== 4/16 ==============
epoch:0 | 0.116 - 0.098
epoch:1 | 0.246 - 0.139
epoch:2 | 0.395 - 0.188
epoch:3 | 0.48 - 0.25
epoch:4 | 0.532 - 0.32
epoch:5 | 0.598 - 0.378
epoch:6 | 0.64 - 0.445
epoch:7 | 0.662 - 0.496
epoch:8 | 0.702 - 0.55
epoch:9 | 0.724 - 0.583
epoch:10 | 0.759 - 0.616
epoch:11 | 0.784 - 0.645
epoch:12 | 0.794 - 0.679
epoch:13 | 0.815 - 0.686
epoch:14 | 0.832 - 0.716
epoch:15 | 0.842 - 0.744
epoch:16 | 0.858 - 0.754
epoch:17 | 0.866 - 0.774
epoch:18 | 0.878 - 0.776
epoch:19 | 0.882 - 0.794
============== 5/16 ==============
epoch:0 | 0.11 - 0.122
epoch:1 | 0.117 - 0.137
epoch:2 | 0.119 - 0.259
epoch:3 | 0.123 - 0.363
epoch:4 | 0.135 - 0.45
epoch:5 | 0.137 - 0.511
epoch:6 | 0.145 - 0.572
epoch:7 | 0.148 - 0.617
epoch:8 | 0.157 - 0.662
epoch:9 | 0.163 - 0.696
epoch:10 | 0.18 - 0.721
epoch:11 | 0.197 - 0.74
epoch:12 | 0.219 - 0.758
epoch:13 | 0.247 - 0.781
epoch:14 | 0.278 - 0.796
epoch:15 | 0.298 - 0.819
epoch:16 | 0.311 - 0.836
epoch:17 | 0.322 - 0.849
epoch:18 | 0.33 - 0.862
epoch:19 | 0.318 - 0.874
============== 6/16 ==============
epoch:0 | 0.117 - 0.144
epoch:1 | 0.101 - 0.22
epoch:2 | 0.112 - 0.495
epoch:3 | 0.14 - 0.649
epoch:4 | 0.133 - 0.724
epoch:5 | 0.117 - 0.77
epoch:6 | 0.116 - 0.806
epoch:7 | 0.116 - 0.816
epoch:8 | 0.116 - 0.836
epoch:9 | 0.116 - 0.847
epoch:10 | 0.116 - 0.866
epoch:11 | 0.116 - 0.882
epoch:12 | 0.116 - 0.899
epoch:13 | 0.114 - 0.909
epoch:14 | 0.146 - 0.924
epoch:15 | 0.151 - 0.929
epoch:16 | 0.117 - 0.941
epoch:17 | 0.117 - 0.951
epoch:18 | 0.117 - 0.948
epoch:19 | 0.117 - 0.962
============== 7/16 ==============
epoch:0 | 0.097 - 0.102
epoch:1 | 0.113 - 0.26
epoch:2 | 0.116 - 0.632
epoch:3 | 0.116 - 0.731
epoch:4 | 0.116 - 0.768
epoch:5 | 0.116 - 0.79
epoch:6 | 0.116 - 0.819
epoch:7 | 0.116 - 0.852
epoch:8 | 0.116 - 0.878
epoch:9 | 0.116 - 0.896
epoch:10 | 0.116 - 0.912
epoch:11 | 0.116 - 0.925
epoch:12 | 0.116 - 0.934
epoch:13 | 0.116 - 0.949
epoch:14 | 0.116 - 0.954
epoch:15 | 0.116 - 0.967
epoch:16 | 0.116 - 0.969
epoch:17 | 0.116 - 0.975
epoch:18 | 0.116 - 0.98
epoch:19 | 0.116 - 0.983
============== 8/16 ==============
epoch:0 | 0.116 - 0.099
epoch:1 | 0.116 - 0.455
epoch:2 | 0.1 - 0.667
epoch:3 | 0.1 - 0.747
epoch:4 | 0.117 - 0.811
epoch:5 | 0.117 - 0.844
epoch:6 | 0.117 - 0.886
epoch:7 | 0.117 - 0.928
epoch:8 | 0.117 - 0.95
epoch:9 | 0.117 - 0.961
epoch:10 | 0.117 - 0.982
epoch:11 | 0.117 - 0.985
epoch:12 | 0.117 - 0.989
epoch:13 | 0.117 - 0.994
epoch:14 | 0.117 - 0.995
epoch:15 | 0.117 - 0.996
epoch:16 | 0.117 - 0.996
epoch:17 | 0.117 - 0.998
epoch:18 | 0.117 - 0.998
epoch:19 | 0.117 - 0.998
============== 9/16 ==============
epoch:0 | 0.117 - 0.093
epoch:1 | 0.093 - 0.649
epoch:2 | 0.116 - 0.759
epoch:3 | 0.116 - 0.799
epoch:4 | 0.117 - 0.824
epoch:5 | 0.117 - 0.869
epoch:6 | 0.117 - 0.882
epoch:7 | 0.117 - 0.903
epoch:8 | 0.117 - 0.916
epoch:9 | 0.117 - 0.94
epoch:10 | 0.117 - 0.966
epoch:11 | 0.117 - 0.986
epoch:12 | 0.117 - 0.993
epoch:13 | 0.117 - 0.998
epoch:14 | 0.117 - 0.999
epoch:15 | 0.117 - 0.999
epoch:16 | 0.117 - 1.0
epoch:17 | 0.117 - 1.0
epoch:18 | 0.117 - 1.0
epoch:19 | 0.117 - 1.0
============== 10/16 ==============
epoch:0 | 0.094 - 0.188
epoch:1 | 0.094 - 0.64
epoch:2 | 0.117 - 0.692
epoch:3 | 0.116 - 0.84
epoch:4 | 0.116 - 0.905
epoch:5 | 0.116 - 0.951
epoch:6 | 0.116 - 0.922
epoch:7 | 0.116 - 0.976
epoch:8 | 0.116 - 0.981
epoch:9 | 0.116 - 0.98
epoch:10 | 0.116 - 0.992
epoch:11 | 0.116 - 0.991
epoch:12 | 0.117 - 0.993
epoch:13 | 0.116 - 0.995
epoch:14 | 0.116 - 0.996
epoch:15 | 0.116 - 0.998
epoch:16 | 0.116 - 0.997
epoch:17 | 0.116 - 0.998
epoch:18 | 0.116 - 0.999
epoch:19 | 0.116 - 0.999
============== 11/16 ==============
epoch:0 | 0.117 - 0.153
epoch:1 | 0.117 - 0.623
epoch:2 | 0.117 - 0.693
epoch:3 | 0.116 - 0.768
epoch:4 | 0.116 - 0.734
epoch:5 | 0.116 - 0.858
epoch:6 | 0.116 - 0.863
epoch:7 | 0.116 - 0.861
epoch:8 | 0.117 - 0.939
epoch:9 | 0.117 - 0.953
epoch:10 | 0.116 - 0.978
epoch:11 | 0.116 - 0.975
epoch:12 | 0.116 - 0.992
epoch:13 | 0.116 - 0.992
epoch:14 | 0.116 - 0.991
epoch:15 | 0.116 - 0.994
epoch:16 | 0.116 - 0.995
epoch:17 | 0.116 - 0.995
epoch:18 | 0.116 - 0.997
epoch:19 | 0.116 - 0.998
============== 12/16 ==============
epoch:0 | 0.087 - 0.141
epoch:1 | 0.117 - 0.631
epoch:2 | 0.117 - 0.761
epoch:3 | 0.117 - 0.777
epoch:4 | 0.117 - 0.797
epoch:5 | 0.117 - 0.767
epoch:6 | 0.117 - 0.856
epoch:7 | 0.117 - 0.88
epoch:8 | 0.117 - 0.885
epoch:9 | 0.117 - 0.864
epoch:10 | 0.117 - 0.954
epoch:11 | 0.117 - 0.927
epoch:12 | 0.117 - 0.976
epoch:13 | 0.117 - 0.962
epoch:14 | 0.117 - 0.988
epoch:15 | 0.117 - 0.979
epoch:16 | 0.117 - 0.99
epoch:17 | 0.117 - 0.987
epoch:18 | 0.117 - 0.993
epoch:19 | 0.117 - 0.994
============== 13/16 ==============
epoch:0 | 0.099 - 0.164
epoch:1 | 0.099 - 0.285
epoch:2 | 0.117 - 0.487
epoch:3 | 0.116 - 0.526
epoch:4 | 0.116 - 0.673
epoch:5 | 0.116 - 0.686
epoch:6 | 0.116 - 0.683
epoch:7 | 0.117 - 0.677
epoch:8 | 0.116 - 0.701
epoch:9 | 0.116 - 0.677
epoch:10 | 0.117 - 0.709
epoch:11 | 0.117 - 0.71
epoch:12 | 0.117 - 0.7
epoch:13 | 0.117 - 0.711
epoch:14 | 0.117 - 0.709
epoch:15 | 0.117 - 0.71
epoch:16 | 0.117 - 0.714
epoch:17 | 0.116 - 0.717
epoch:18 | 0.116 - 0.809
epoch:19 | 0.116 - 0.81
============== 14/16 ==============
epoch:0 | 0.099 - 0.207
epoch:1 | 0.116 - 0.324
epoch:2 | 0.117 - 0.408
epoch:3 | 0.117 - 0.486
epoch:4 | 0.117 - 0.491
epoch:5 | 0.117 - 0.509
epoch:6 | 0.117 - 0.511
epoch:7 | 0.117 - 0.504
epoch:8 | 0.117 - 0.5
epoch:9 | 0.117 - 0.517
epoch:10 | 0.117 - 0.519
epoch:11 | 0.117 - 0.519
epoch:12 | 0.117 - 0.375
epoch:13 | 0.117 - 0.519
epoch:14 | 0.117 - 0.522
epoch:15 | 0.117 - 0.522
epoch:16 | 0.117 - 0.521
epoch:17 | 0.117 - 0.614
epoch:18 | 0.116 - 0.615
epoch:19 | 0.117 - 0.616
============== 15/16 ==============
epoch:0 | 0.117 - 0.198
epoch:1 | 0.117 - 0.358
epoch:2 | 0.117 - 0.493
epoch:3 | 0.117 - 0.346
epoch:4 | 0.117 - 0.505
epoch:5 | 0.117 - 0.51
epoch:6 | 0.117 - 0.516
epoch:7 | 0.117 - 0.519
epoch:8 | 0.117 - 0.56
epoch:9 | 0.117 - 0.541
epoch:10 | 0.117 - 0.533
epoch:11 | 0.117 - 0.541
epoch:12 | 0.116 - 0.568
epoch:13 | 0.116 - 0.608
epoch:14 | 0.116 - 0.609
epoch:15 | 0.116 - 0.613
epoch:16 | 0.116 - 0.616
epoch:17 | 0.116 - 0.62
epoch:18 | 0.116 - 0.615
epoch:19 | 0.116 - 0.652
============== 16/16 ==============
epoch:0 | 0.092 - 0.092
epoch:1 | 0.094 - 0.288
epoch:2 | 0.116 - 0.373
epoch:3 | 0.116 - 0.407
epoch:4 | 0.116 - 0.416
epoch:5 | 0.116 - 0.418
epoch:6 | 0.116 - 0.488
epoch:7 | 0.117 - 0.493
epoch:8 | 0.117 - 0.502
epoch:9 | 0.117 - 0.517
epoch:10 | 0.117 - 0.52
epoch:11 | 0.117 - 0.507
epoch:12 | 0.117 - 0.524
epoch:13 | 0.117 - 0.521
epoch:14 | 0.117 - 0.523
epoch:15 | 0.117 - 0.522
epoch:16 | 0.117 - 0.522
epoch:17 | 0.116 - 0.523
epoch:18 | 0.116 - 0.481
epoch:19 | 0.116 - 0.509