在深度学习中,可以通过学习曲线评估当前训练状态:
- train loss 不断下降,test loss 不断下降,说明网络仍在正常学习。
- train loss 不断下降,test loss 趋于不变,说明网络过拟合。
- train loss 趋于不变,test loss 趋于不变,说明学习遇到瓶颈,需减小学习速率或者批量数据尺寸。
- train loss 趋于不变,test loss 不断下降,说明数据集很可能有问题(例如训练集与测试集的划分或标注存在问题)。
- train loss 不断上升,test loss 不断上升(最终为 NaN),可能是网络结构设计不当、训练超参数设置不当或程序 bug 等问题引起的,需要进一步定位。
Linux下的MATLAB代码:
% Plot the train and test loss curves from a Caffe training log (Linux).
% Shell command used below to pull the loss values out of the log:
%   cat train_log_file | grep "Train net output " | awk '{print $11}'
% Field 11 is the numeric value in a line such as (example):
%   I0210 13:39:22.381027 25210 solver.cpp:228]     Train net output #0: loss = 0.1 (* 1 = 0.1 loss)
clear;
clc;
close all;
train_log_file = 'train.log';
train_interval = 100;  % iterations between two logged train losses (assumed to match the solver's display setting)
test_interval = 200;   % iterations between two test phases (assumed to match the solver's test_interval)
% system() runs the command through the platform shell; dos() is the Windows-oriented variant.
[~, train_string_output] = system(['cat ', train_log_file, ' | grep ''Train net output #0'' | awk ''{print $11}''']);
train_loss = str2num(train_string_output);
if isempty(train_loss)
    error('No "Train net output #0" values found in %s', train_log_file);
end
n = 1 : length(train_loss);
idx_train = (n - 1) * train_interval;
[~, test_string_output] = system(['cat ', train_log_file, ' | grep ''Test net output #1'' | awk ''{print $11}''']);
test_loss = str2num(test_string_output);
if isempty(test_loss)
    error('No "Test net output #1" values found in %s', train_log_file);
end
m = 1 : length(test_loss);
idx_test = (m - 1) * test_interval;
figure;
plot(idx_train, train_loss);
hold on;
plot(idx_test, test_loss);
grid on;
legend('Train Loss', 'Test Loss');
xlabel('iterations');
ylabel('loss');
title(' Train & Test Loss Curve');
Windows 下的 Python3(Anaconda3 + PyCharm)代码。首先用以下批处理命令训练网络并保存训练日志(命令中的路径均为相对路径,需在 Caffe 根目录下运行):
REM Train LeNet on MNIST with caffe.exe. All paths are relative to the current
REM directory (presumably the Caffe root - confirm before running).
REM "> file 2>&1" sends both stdout and stderr into the log file that is later parsed by parse_log.py.
"./bin/caffe.exe" train --solver=./examples/mnist/lenet_solver.prototxt >./examples/mnist/log/mnist_Lenet_train_test.log 2>&1
REM Wait for a key press so the console window stays open after training ends.
pause
命令中的 `>./examples/mnist/log/mnist_Lenet_train_test.log 2>&1` 表示把训练日志(标准输出和标准错误)重定向保存到该文件中。
parse_log.py和extract_seconds.py文件用于解析训练日志:
parse_log.py源码:
import csv
import os
import re
from collections import OrderedDict

from examples.mnist.log.extract_seconds import *
def parse_log(log_file_name):
    """Parse a Caffe training log into train and test tables.

    Lines before the first "Iteration N" match are ignored. Each table row is
    an OrderedDict holding 'NumIters', 'Seconds' (elapsed since the solver's
    "Solving" line), 'LearningRate' (the most recent "lr = ..." value, NaN
    until one is seen) and one column per "Train/Test net output #k: name = value"
    entry logged for that iteration.

    :param log_file_name: path of the log file
    :return: (train_dict_list, test_dict_list)
    """
    # Raw strings: '\d', '\S' and '\.' are invalid escape sequences in ordinary
    # string literals (SyntaxWarning since Python 3.12); the regexes are unchanged.
    regex_iteration = re.compile(r'Iteration (\d+)')
    regex_train_output = re.compile(r'Train net output #(\d+): (\S+) = ([.\deE+-]+)')
    regex_test_output = re.compile(r'Test net output #(\d+): (\S+) = ([.\deE+-]+)')
    regex_learning_rate = re.compile(r'lr = ([-+]?[0-9]*\.?[0-9]+([eE]?[-+]?[0-9]+)?)')

    iteration = -1
    learning_rate = float('NaN')
    train_dict_list = []
    test_dict_list = []
    train_row = None
    test_row = None

    # Log timestamps carry no year: start from the file's ctime year and bump
    # it whenever the month wraps around (see below).
    logfile_year = get_log_created_year(log_file_name)
    with open(log_file_name) as f:
        # get_start_time reads f up to and including the "Solving" line, so
        # the loop below continues from the line after it.
        start_time = get_start_time(f, logfile_year)
        last_time = start_time
        for line in f:
            iteration_match = regex_iteration.search(line)
            if iteration_match:
                iteration = float(iteration_match.group(1))
            if iteration == -1:
                # Only start parsing other fields once the first iteration is found
                continue
            try:
                line_time = extract_datetime_from_line(line, logfile_year)
            except (ValueError, IndexError):
                # Skip lines without a "Immdd hh:mm:ss.ffffff" prefix: badly
                # formatted lines (e.g. when resuming the solver) raise
                # ValueError, blank lines raise IndexError.
                continue
            # The month went backwards, so the log has crossed into a new year.
            if line_time.month < last_time.month:
                logfile_year += 1
                line_time = extract_datetime_from_line(line, logfile_year)
            last_time = line_time
            seconds = (last_time - start_time).total_seconds()

            learning_rate_match = regex_learning_rate.search(line)
            if learning_rate_match:
                learning_rate = float(learning_rate_match.group(1))

            train_dict_list, train_row = parse_line_for_net_output(
                regex_train_output, train_row, train_dict_list, line, iteration, seconds, learning_rate)
            test_dict_list, test_row = parse_line_for_net_output(
                regex_test_output, test_row, test_dict_list, line, iteration, seconds, learning_rate)

    fix_initial_nan_learning_rate(train_dict_list)
    fix_initial_nan_learning_rate(test_dict_list)
    return train_dict_list, test_dict_list
def parse_line_for_net_output(regex_obj, row, row_dict_list, line, iteration, seconds, learning_rate):
    """Fold one log line into the train or test table being built.

    :param regex_obj: compiled pattern for net-output lines; group 2 is the
        output name and group 3 its numeric value
    :param row: the row currently being filled, or None
    :param row_dict_list: rows completed so far (appended to in place)
    :param line: the raw log line
    :param iteration: iteration number the line belongs to
    :param seconds: elapsed seconds stored in a newly started row
    :param learning_rate: learning rate stored in a newly started row
    :return: (row_dict_list, row) - the possibly extended list, and the row
        still being filled (a new row, the same row with one more column, or
        None once it has been appended)
    """
    hit = regex_obj.search(line)
    if hit:
        same_iteration = bool(row) and row['NumIters'] == iteration
        if not same_iteration:
            if row:
                # Starting a new iteration: flush the unfinished previous row.
                # In practice this only happens for the very first row, since
                # later rows are flushed by the "full row" check below.
                row_dict_list.append(row)
            row = OrderedDict([('NumIters', iteration), ('Seconds', seconds), ('LearningRate', learning_rate)])
        # group(1), the output index, is deliberately not stored.
        row[hit.group(2)] = float(hit.group(3))

    # A row is complete once it has as many columns as the first stored row.
    if row and row_dict_list and len(row) == len(row_dict_list[0]):
        row_dict_list.append(row)
        row = None
    return row_dict_list, row
def fix_initial_nan_learning_rate(dict_list):
    """Correct the initial value of the learning rate in place.

    The learning rate is normally not printed until after the initial test
    and training step, so the first testing and training rows carry
    LearningRate = NaN. Replace that placeholder with the LearningRate of the
    second row, if a second row exists. A first row that already has a real
    learning rate is left untouched.

    :param dict_list: list of row dicts, each with a 'LearningRate' key
    :return: None (the list is modified in place)
    """
    import math

    if len(dict_list) < 2:
        return
    # Only the NaN placeholder is replaced; overwriting unconditionally would
    # clobber a genuine initial learning rate.
    if math.isnan(dict_list[0]['LearningRate']):
        dict_list[0]['LearningRate'] = dict_list[1]['LearningRate']
def save_csv_files(logfile, output_dir, train_dict_list, test_dict_list, delimiter=',', verbose=False):
    """Write the train and test tables as CSV files into output_dir.

    The file names are the log's base name plus '.train' / '.test', so a log
    called caffe.INFO produces caffe.INFO.train and caffe.INFO.test. The
    train file is written first, then the test file.
    """
    base = os.path.basename(logfile)
    for suffix, rows in (('.train', train_dict_list), ('.test', test_dict_list)):
        write_csv(os.path.join(output_dir, base + suffix), rows, delimiter, verbose)
def write_csv(output_filename, dict_list, delimiter, verbose=False):
    """Write a list of row dicts to a CSV file with a header row.

    The column order is taken from the keys of the first row. Nothing is
    written (and no file is created) when dict_list is empty.

    :param output_filename: path of the CSV file to create/overwrite
    :param dict_list: rows to write
    :param delimiter: field separator character
    :param verbose: print a status message when True
    :return: None
    """
    if not dict_list:
        if verbose:
            print('Not writing %s; no lines to write' % output_filename)
        return
    # NOTE(review): the file is opened without newline='' (which the csv docs
    # recommend). On Windows this yields an extra blank line after every row;
    # it is kept as-is because readers of these files may rely on that layout.
    with open(output_filename, 'w') as f:
        # Pass the delimiter as a formatting parameter. Assigning to
        # csv.excel.delimiter would mutate the shared dialect class and change
        # the delimiter for every other csv user in the process.
        dict_writer = csv.DictWriter(
            f, fieldnames=list(dict_list[0].keys()), dialect=csv.excel, delimiter=delimiter)
        dict_writer.writeheader()
        dict_writer.writerows(dict_list)
    if verbose:
        print('Wrote %s' % output_filename)
def main():
    """Parse the MNIST LeNet training log and save its train/test tables as CSV files."""
    log_file_name = 'mnist_Lenet_train_test.log'
    # Directory where the parsed CSV files are saved (must use '#', not '//':
    # in Python '//' is floor division and the trailing text would be looked
    # up as a variable name, raising NameError).
    output_dir = 'C:\\Programming Code\\Caffe\\examples\\mnist\\log\\'
    train_dict_list, test_dict_list = parse_log(log_file_name)
    save_csv_files(log_file_name, output_dir, train_dict_list, test_dict_list, delimiter=',')


if __name__ == '__main__':
    main()
extract_seconds.py源码:
import datetime
import os
def extract_datetime_from_line(line, year):
    """Return the timestamp at the start of a glog-formatted log line.

    Expected format:
        I0210 13:39:22.381027 25210 solver.cpp:204] Iteration 100, lr = 0.00992565
    The first token is one severity letter followed by MMDD; the second token
    is hh:mm:ss.ffffff. The line carries no year, so the caller supplies it.

    :param line: one line of the log file
    :param year: year to combine with the parsed month and day
    :return: naive datetime.datetime
    :raises ValueError: if a date/time field is not an integer or is out of range
    :raises IndexError: if the line has too few fields
    """
    tokens = line.strip().split()
    date_token = tokens[0]
    month = int(date_token[1:3])
    day = int(date_token[3:])
    clock = tokens[1]
    dot = clock.rfind('.')
    hms = [int(part) for part in clock[:dot].split(':')]
    # Arguments are evaluated left to right: hour/minute/second first, then microseconds.
    return datetime.datetime(year, month, day, hms[0], hms[1], hms[2], int(clock[dot + 1:]))
def get_log_created_year(input_file):
    """Return the local-time calendar year of the log file's ctime.

    os.path.getctime is the creation time on Windows but the last metadata
    change time on Unix, so on Unix this only approximates the year the log
    was created.

    :param input_file: path of the log file
    :return: the year as an int
    """
    return datetime.datetime.fromtimestamp(os.path.getctime(input_file)).year
def get_start_time(line_iterable, year):
    """Return the timestamp of the first line containing 'Solving'.

    Lines are consumed from line_iterable up to and including the matching
    line, so a caller iterating the same file object afterwards resumes from
    the following line.

    :param line_iterable: iterable of log lines (e.g. an open file)
    :param year: year to attach to the parsed timestamp
    :return: the start datetime, or None if no 'Solving' line exists
    """
    for raw_line in line_iterable:
        if 'Solving' in raw_line:
            return extract_datetime_from_line(raw_line.strip(), year)
    return None
绘图源码:
import matplotlib.pyplot as plt
import random
import itertools
def load_data(data_file, phase):
    """Load the columns needed for plotting from a CSV file written by parse_log.py.

    The first non-blank line is the header and is skipped. Blank lines are
    ignored: a CSV written in text mode on Windows has an empty line after
    every row, while elsewhere it does not, and both layouts are handled.

    :param data_file: path of the .train or .test CSV file
    :param phase: 'Train' for the train table; anything else is treated as test
    :return: for 'Train', [iterations, learning_rates, losses]; otherwise
        [iterations, learning_rates, accuracies, losses] (assumes the test net
        logs accuracy before loss, as in the MNIST LeNet example)
    """
    # CSV column order: NumIters, Seconds, LearningRate, then net outputs in log order.
    columns = (0, 2, 3) if phase == 'Train' else (0, 2, 3, 4)
    data = [[] for _ in columns]
    header_seen = False
    with open(data_file, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            if not header_seen:
                header_seen = True
                continue
            fields = line.split(",")
            for series, col in zip(data, columns):
                series.append(float(fields[col].strip()))
    return data
def _plot_series(path_to_png, title, x_values, y_values, y_label, line_width):
    """Draw one curve in a figure named `title` and save it as path_to_png + title + '.png'."""
    color = [random.random(), random.random(), random.random()]  # random line colour
    fig = plt.figure(title)
    plt.plot(x_values, y_values, color=color, linewidth=line_width)
    plt.title(title)
    plt.xlabel('Iterations')
    plt.ylabel(y_label)
    plt.savefig(path_to_png + title + '.png')
    # Close the figure once saved: otherwise every call keeps a figure alive,
    # and a later call with the same title would draw on top of the old curve.
    plt.close(fig)


def plot_chart(path_to_png, data, phase):
    """Plot and save the charts for one phase of the parsed log.

    Train phase: "Train Iterations VS Loss" and "Train Iterations VS LearningRate".
    Any other phase: "Test Iterations VS Loss", "Test Iterations VS LearningRate"
    and "Test Iterations VS Accuracy". Each chart is saved as
    path_to_png + <title> + '.png' (path_to_png is used as a plain prefix,
    so it should end with a path separator when it names a directory).

    :param path_to_png: prefix (usually a directory with trailing separator) for the png files
    :param data: lists as returned by load_data for the same phase
    :param phase: 'Train' or the test phase
    :return: None
    """
    line_width = 1.0
    if phase == 'Train':
        iterations, learning_rate, loss = data[0], data[1], data[2]
        series = [
            ('Train Iterations VS Loss', loss, 'Loss'),
            ('Train Iterations VS LearningRate', learning_rate, 'LearningRate'),
        ]
    else:
        iterations, learning_rate, accuracy, loss = data[0], data[1], data[2], data[3]
        series = [
            ('Test Iterations VS Loss', loss, 'Loss'),
            ('Test Iterations VS LearningRate', learning_rate, 'LearningRate'),
            ('Test Iterations VS Accuracy', accuracy, 'Accuracy'),
        ]
    for title, values, y_label in series:
        _plot_series(path_to_png, title, iterations, values, y_label, line_width)
def main():
    """Draw the train charts, then the test charts, from the CSV files produced by parse_log.py."""
    png_prefix = 'C:\\Programming Code\\Caffe\\examples\\mnist\\log\\'
    jobs = (
        ('Train', 'mnist_Lenet_train_test.log.train'),
        ('Test', 'mnist_Lenet_train_test.log.test'),
    )
    # Load and plot one phase at a time, train first.
    for phase_name, csv_file in jobs:
        phase_data = load_data(csv_file, phase=phase_name)
        plot_chart(png_prefix, phase_data, phase=phase_name)


if __name__ == '__main__':
    main()