炼丹软件
github链接:
有需要联系我
requirements:
经测试,可在 Ubuntu 18.04 和 Windows 上运行
ubuntu18.04
OS: Ubuntu 18.04.6 LTS
Python version: 3.7 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 11.1.74
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090 Ti
Nvidia driver version: 510.108.03
安装可能存在的问题:
No module named 'kornia'
pip install kornia==0.5
注:不指定版本号时,pip 会安装最新版 kornia,并可能连带把 torch 升级到新版本
No module named 'PyQt5.QtChart'
需单独安装
pip install PyQtChart
ModuleNotFoundError: No module named 'qt_material'
pip install qt-material
ModuleNotFoundError: No module named 'jinja2'
pip install jinja2
主函数
if __name__ == '__main__':
    # multiprocessing.freeze_support()
    # Create the Qt application object.
    app = QApplication(sys.argv)
    # Apply the project stylesheet ``StyleSheet`` (defined elsewhere in the project).
    app.setStyleSheet(StyleSheet)
    # Show the animated splash screen while the main window is being built.
    # NOTE: ``global`` is a no-op at module scope; ``splash`` is already a
    # module-level name, and MainCode.__init__ reads it directly, so it must be
    # assigned before MainCode() is constructed.
    global splash
    splash = GifSplashScreen()
    splash.show()
    m_window = MainCode()
    # qt_material theme. NOTE(review): presumably this replaces the stylesheet
    # set above -- confirm whether both calls are needed.
    apply_stylesheet(app, theme='dark_blue.xml') #, invert_secondary=True)
    m_window.show()
    # Run the Qt event loop and exit with its return code.
    sys.exit(app.exec_())
启动界面
在主界面加载完成之前先显示一个启动界面,以减少用户在主程序加载过程中的等待感。
# Splash screen
class GifSplashScreen(QSplashScreen):
    """Splash screen that plays ./Lib/splash.gif while the main window loads."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        gif = QMovie('./Lib/splash.gif')
        self.movie = gif
        # Repaint the splash with every new frame of the animation.
        gif.frameChanged.connect(self.onFrameChanged)
        gif.start()

    def onFrameChanged(self, _frame):
        """Show the movie's current frame (the frame number is unused)."""
        self.setPixmap(self.movie.currentPixmap())

    def finish(self, widget):
        """Stop the animation, then let QSplashScreen close once *widget* is shown."""
        self.movie.stop()
        super().finish(widget)
主界面设计
利用 Qt Designer 设计界面,再通过 PyCharm 外部工具 PyUIC 转换成 py 文件
class Ui_MainWindow(object):
    """Widget tree of the main window, generated by PyUIC from a Qt Designer form.

    Left side: a QToolBox with five step pages, each holding a button.
    Right side: a QTabWidget with a "home" tab plus one tab per step.
    """
    def setupUi(self, MainWindow):
        """Create all child widgets, layouts, menu, status bar and toolbar on *MainWindow*."""
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(1000, 800)
        icon = QtGui.QIcon()
        icon.addPixmap(QtGui.QPixmap("./icon.ico"), QtGui.QIcon.Normal, QtGui.QIcon.Off)
        MainWindow.setWindowIcon(icon)
        self.centralwidget = QtWidgets.QWidget(MainWindow)
        self.centralwidget.setObjectName("centralwidget")
        # Central widget: toolbox (left) and tab widget (right) side by side.
        self.horizontalLayout = QtWidgets.QHBoxLayout(self.centralwidget)
        self.horizontalLayout.setObjectName("horizontalLayout")
        # --- Left-hand step toolbox (at most 200 px wide) ---
        self.toolBox = QtWidgets.QToolBox(self.centralwidget)
        self.toolBox.setMaximumSize(QtCore.QSize(200, 16777215))
        self.toolBox.setObjectName("toolBox")
        # Step 1 page: choose the image folder.
        self.m_First = QtWidgets.QWidget()
        self.m_First.setGeometry(QtCore.QRect(0, 0, 200, 597))
        self.m_First.setObjectName("m_First")
        self.horizontalLayout_2 = QtWidgets.QHBoxLayout(self.m_First)
        self.horizontalLayout_2.setObjectName("horizontalLayout_2")
        self.m_ImageDirBtn = QtWidgets.QPushButton(self.m_First)
        self.m_ImageDirBtn.setObjectName("m_ImageDirBtn")
        self.horizontalLayout_2.addWidget(self.m_ImageDirBtn)
        self.toolBox.addItem(self.m_First, "")
        # Step 2 page: set training parameters.
        self.m_Second = QtWidgets.QWidget()
        self.m_Second.setGeometry(QtCore.QRect(0, 0, 200, 597))
        self.m_Second.setObjectName("m_Second")
        self.horizontalLayout_3 = QtWidgets.QHBoxLayout(self.m_Second)
        self.horizontalLayout_3.setObjectName("horizontalLayout_3")
        self.m_SetTrainPareBtn = QtWidgets.QPushButton(self.m_Second)
        self.m_SetTrainPareBtn.setObjectName("m_SetTrainPareBtn")
        self.horizontalLayout_3.addWidget(self.m_SetTrainPareBtn)
        self.toolBox.addItem(self.m_Second, "")
        # Step 3 page: start training.
        self.m_Third = QtWidgets.QWidget()
        self.m_Third.setGeometry(QtCore.QRect(0, 0, 200, 597))
        self.m_Third.setObjectName("m_Third")
        self.verticalLayout = QtWidgets.QVBoxLayout(self.m_Third)
        self.verticalLayout.setObjectName("verticalLayout")
        self.m_StartTrainBtn = QtWidgets.QPushButton(self.m_Third)
        self.m_StartTrainBtn.setObjectName("m_StartTrainBtn")
        self.verticalLayout.addWidget(self.m_StartTrainBtn)
        self.toolBox.addItem(self.m_Third, "")
        # Step 4 page: detect an image.
        self.m_Forth = QtWidgets.QWidget()
        self.m_Forth.setGeometry(QtCore.QRect(0, 0, 200, 597))
        self.m_Forth.setObjectName("m_Forth")
        self.verticalLayout_2 = QtWidgets.QVBoxLayout(self.m_Forth)
        self.verticalLayout_2.setObjectName("verticalLayout_2")
        self.m_DetectSinglePicBtn = QtWidgets.QPushButton(self.m_Forth)
        self.m_DetectSinglePicBtn.setObjectName("m_DetectSinglePicBtn")
        self.verticalLayout_2.addWidget(self.m_DetectSinglePicBtn)
        self.toolBox.addItem(self.m_Forth, "")
        # Step 5 page: model conversion. Unlike the other pages it has no
        # layout; the button is placed at a fixed geometry.
        self.m_fifth = QtWidgets.QWidget()
        self.m_fifth.setObjectName("m_fifth")
        self.m_convertModel = QtWidgets.QPushButton(self.m_fifth)
        self.m_convertModel.setGeometry(QtCore.QRect(50, 180, 75, 23))
        self.m_convertModel.setObjectName("m_convertModel")
        self.toolBox.addItem(self.m_fifth, "")
        self.horizontalLayout.addWidget(self.toolBox)
        # --- Right-hand tab widget ---
        self.tabWidget = QtWidgets.QTabWidget(self.centralwidget)
        self.tabWidget.setObjectName("tabWidget")
        # Tab 0 "home": a label that expands to fill the tab.
        self.home_page = QtWidgets.QWidget()
        self.home_page.setObjectName("home_page")
        self.gridLayout_8 = QtWidgets.QGridLayout(self.home_page)
        self.gridLayout_8.setObjectName("gridLayout_8")
        self.m_homePagelabel = QtWidgets.QLabel(self.home_page)
        sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)
        sizePolicy.setHorizontalStretch(0)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(self.m_homePagelabel.sizePolicy().hasHeightForWidth())
        self.m_homePagelabel.setSizePolicy(sizePolicy)
        self.m_homePagelabel.setObjectName("m_homePagelabel")
        self.gridLayout_8.addWidget(self.m_homePagelabel, 0, 0, 1, 1)
        self.tabWidget.addTab(self.home_page, "")
        # Tab 1 (step 1): folder path edit, image preview label, and a file
        # list table (at most 300 px wide).
        self.m_FirstW = QtWidgets.QWidget()
        self.m_FirstW.setObjectName("m_FirstW")
        self.gridLayout = QtWidgets.QGridLayout(self.m_FirstW)
        self.gridLayout.setObjectName("gridLayout")
        self.le_imageDir = QtWidgets.QLineEdit(self.m_FirstW)
        self.le_imageDir.setObjectName("le_imageDir")
        self.gridLayout.addWidget(self.le_imageDir, 0, 0, 1, 1)
        self.tableImgList = QtWidgets.QTableWidget(self.m_FirstW)
        self.tableImgList.setMaximumSize(QtCore.QSize(300, 16777215))
        self.tableImgList.setObjectName("tableImgList")
        self.tableImgList.setColumnCount(0)
        self.tableImgList.setRowCount(0)
        self.gridLayout.addWidget(self.tableImgList, 1, 1, 1, 1)
        self.la_image = QtWidgets.QLabel(self.m_FirstW)
        self.la_image.setObjectName("la_image")
        self.gridLayout.addWidget(self.la_image, 1, 0, 1, 1)
        self.tabWidget.addTab(self.m_FirstW, "")
        # Tab 2 (step 2): hyper-parameter form, one label/line-edit pair per
        # row, then a CUDA checkbox and an OK button.
        self.m_SecondW = QtWidgets.QWidget()
        self.m_SecondW.setObjectName("m_SecondW")
        self.gridLayout_5 = QtWidgets.QGridLayout(self.m_SecondW)
        self.gridLayout_5.setObjectName("gridLayout_5")
        self.m_MaxIterationL = QtWidgets.QLabel(self.m_SecondW)
        self.m_MaxIterationL.setObjectName("m_MaxIterationL")
        self.gridLayout_5.addWidget(self.m_MaxIterationL, 0, 0, 1, 1)
        self.m_MaxIterationEd = QtWidgets.QLineEdit(self.m_SecondW)
        self.m_MaxIterationEd.setObjectName("m_MaxIterationEd")
        self.gridLayout_5.addWidget(self.m_MaxIterationEd, 0, 1, 1, 1)
        self.m_BathSizeL = QtWidgets.QLabel(self.m_SecondW)
        self.m_BathSizeL.setObjectName("m_BathSizeL")
        self.gridLayout_5.addWidget(self.m_BathSizeL, 1, 0, 1, 1)
        self.m_BathSizeEd = QtWidgets.QLineEdit(self.m_SecondW)
        self.m_BathSizeEd.setObjectName("m_BathSizeEd")
        self.gridLayout_5.addWidget(self.m_BathSizeEd, 1, 1, 1, 1)
        self.m_ImageSizeL = QtWidgets.QLabel(self.m_SecondW)
        self.m_ImageSizeL.setObjectName("m_ImageSizeL")
        self.gridLayout_5.addWidget(self.m_ImageSizeL, 2, 0, 1, 1)
        self.m_ImageSizeEd = QtWidgets.QLineEdit(self.m_SecondW)
        self.m_ImageSizeEd.setObjectName("m_ImageSizeEd")
        self.gridLayout_5.addWidget(self.m_ImageSizeEd, 2, 1, 1, 1)
        self.m_ValidationRatioL = QtWidgets.QLabel(self.m_SecondW)
        self.m_ValidationRatioL.setObjectName("m_ValidationRatioL")
        self.gridLayout_5.addWidget(self.m_ValidationRatioL, 3, 0, 1, 1)
        self.m_ValidationRatioEd = QtWidgets.QLineEdit(self.m_SecondW)
        self.m_ValidationRatioEd.setObjectName("m_ValidationRatioEd")
        self.gridLayout_5.addWidget(self.m_ValidationRatioEd, 3, 1, 1, 1)
        self.m_LearningRateL = QtWidgets.QLabel(self.m_SecondW)
        self.m_LearningRateL.setObjectName("m_LearningRateL")
        self.gridLayout_5.addWidget(self.m_LearningRateL, 4, 0, 1, 1)
        self.m_LearningRateEd = QtWidgets.QLineEdit(self.m_SecondW)
        self.m_LearningRateEd.setObjectName("m_LearningRateEd")
        self.gridLayout_5.addWidget(self.m_LearningRateEd, 4, 1, 1, 1)
        self.m_WeightDecayL = QtWidgets.QLabel(self.m_SecondW)
        self.m_WeightDecayL.setObjectName("m_WeightDecayL")
        self.gridLayout_5.addWidget(self.m_WeightDecayL, 5, 0, 1, 1)
        self.m_WeightDecayEd = QtWidgets.QLineEdit(self.m_SecondW)
        self.m_WeightDecayEd.setObjectName("m_WeightDecayEd")
        self.gridLayout_5.addWidget(self.m_WeightDecayEd, 5, 1, 1, 1)
        self.m_isCuda = QtWidgets.QCheckBox(self.m_SecondW)
        self.m_isCuda.setObjectName("m_isCuda")
        self.gridLayout_5.addWidget(self.m_isCuda, 6, 0, 1, 1)
        self.m_OkBtn = QtWidgets.QPushButton(self.m_SecondW)
        self.m_OkBtn.setObjectName("m_OkBtn")
        self.gridLayout_5.addWidget(self.m_OkBtn, 7, 1, 1, 1)
        self.tabWidget.addTab(self.m_SecondW, "")
        # Tab 3 (step 3): training control buttons, model save path,
        # progress bar and a text log.
        self.m_ThirdW = QtWidgets.QWidget()
        self.m_ThirdW.setObjectName("m_ThirdW")
        self.gridLayout_7 = QtWidgets.QGridLayout(self.m_ThirdW)
        self.gridLayout_7.setObjectName("gridLayout_7")
        self.m_trainwidget = QtWidgets.QWidget(self.m_ThirdW)
        self.m_trainwidget.setObjectName("m_trainwidget")
        self.gridLayout_2 = QtWidgets.QGridLayout(self.m_trainwidget)
        self.gridLayout_2.setObjectName("gridLayout_2")
        self.m_initModelBtn = QtWidgets.QPushButton(self.m_trainwidget)
        self.m_initModelBtn.setObjectName("m_initModelBtn")
        self.gridLayout_2.addWidget(self.m_initModelBtn, 0, 0, 1, 1)
        self.m_startTrainBtn = QtWidgets.QPushButton(self.m_trainwidget)
        self.m_startTrainBtn.setObjectName("m_startTrainBtn")
        self.gridLayout_2.addWidget(self.m_startTrainBtn, 0, 1, 1, 1)
        self.m_pauseTrainBtn = QtWidgets.QPushButton(self.m_trainwidget)
        self.m_pauseTrainBtn.setObjectName("m_pauseTrainBtn")
        self.gridLayout_2.addWidget(self.m_pauseTrainBtn, 1, 0, 1, 1)
        self.m_resumTrainBtn = QtWidgets.QPushButton(self.m_trainwidget)
        self.m_resumTrainBtn.setObjectName("m_resumTrainBtn")
        self.gridLayout_2.addWidget(self.m_resumTrainBtn, 1, 1, 1, 1)
        self.m_stopTrainBtn = QtWidgets.QPushButton(self.m_trainwidget)
        self.m_stopTrainBtn.setObjectName("m_stopTrainBtn")
        self.gridLayout_2.addWidget(self.m_stopTrainBtn, 2, 0, 1, 1)
        self.gridLayout_7.addWidget(self.m_trainwidget, 0, 1, 1, 1)
        # Save-path row, fixed at 100 px high.
        self.m_savemodelWidget = QtWidgets.QWidget(self.m_ThirdW)
        self.m_savemodelWidget.setMinimumSize(QtCore.QSize(0, 100))
        self.m_savemodelWidget.setMaximumSize(QtCore.QSize(16777215, 100))
        self.m_savemodelWidget.setObjectName("m_savemodelWidget")
        self.gridLayout_3 = QtWidgets.QGridLayout(self.m_savemodelWidget)
        self.gridLayout_3.setObjectName("gridLayout_3")
        self.m_modelSaveEd = QtWidgets.QLineEdit(self.m_savemodelWidget)
        self.m_modelSaveEd.setObjectName("m_modelSaveEd")
        self.gridLayout_3.addWidget(self.m_modelSaveEd, 0, 1, 1, 1)
        self.m_modelSaveBtn = QtWidgets.QPushButton(self.m_savemodelWidget)
        self.m_modelSaveBtn.setObjectName("m_modelSaveBtn")
        self.gridLayout_3.addWidget(self.m_modelSaveBtn, 0, 3, 1, 1)
        self.m_modelSaveL = QtWidgets.QLabel(self.m_savemodelWidget)
        self.m_modelSaveL.setObjectName("m_modelSaveL")
        self.gridLayout_3.addWidget(self.m_modelSaveL, 0, 0, 1, 1)
        self.gridLayout_7.addWidget(self.m_savemodelWidget, 0, 0, 1, 1)
        self.m_modelTrainProcesssbar = QtWidgets.QProgressBar(self.m_ThirdW)
        # Designer preview value; MainCode.initUi resets the bar to 0.
        self.m_modelTrainProcesssbar.setProperty("value", 24)
        self.m_modelTrainProcesssbar.setObjectName("m_modelTrainProcesssbar")
        self.gridLayout_7.addWidget(self.m_modelTrainProcesssbar, 1, 0, 1, 2)
        self.textBrowser = QtWidgets.QTextBrowser(self.m_ThirdW)
        self.textBrowser.setObjectName("textBrowser")
        self.gridLayout_7.addWidget(self.textBrowser, 3, 0, 1, 2)
        self.tabWidget.addTab(self.m_ThirdW, "")
        # Tab 4 (step 4): load-model row (fixed 50 px high) and a result label
        # (presumably shows the detection output -- confirm with caller).
        self.m_ForthW = QtWidgets.QWidget()
        self.m_ForthW.setObjectName("m_ForthW")
        self.gridLayout_6 = QtWidgets.QGridLayout(self.m_ForthW)
        self.gridLayout_6.setObjectName("gridLayout_6")
        self.m_loadmodelwidget = QtWidgets.QWidget(self.m_ForthW)
        self.m_loadmodelwidget.setMinimumSize(QtCore.QSize(0, 50))
        self.m_loadmodelwidget.setMaximumSize(QtCore.QSize(16777215, 50))
        self.m_loadmodelwidget.setObjectName("m_loadmodelwidget")
        self.gridLayout_4 = QtWidgets.QGridLayout(self.m_loadmodelwidget)
        self.gridLayout_4.setObjectName("gridLayout_4")
        self.m_loadmodelBtn = QtWidgets.QPushButton(self.m_loadmodelwidget)
        self.m_loadmodelBtn.setObjectName("m_loadmodelBtn")
        self.gridLayout_4.addWidget(self.m_loadmodelBtn, 0, 0, 1, 1)
        self.m_loadmodelEd = QtWidgets.QLineEdit(self.m_loadmodelwidget)
        self.m_loadmodelEd.setObjectName("m_loadmodelEd")
        self.gridLayout_4.addWidget(self.m_loadmodelEd, 0, 1, 1, 1)
        self.gridLayout_6.addWidget(self.m_loadmodelwidget, 0, 0, 1, 1)
        self.la_result = QtWidgets.QLabel(self.m_ForthW)
        self.la_result.setObjectName("la_result")
        self.gridLayout_6.addWidget(self.la_result, 1, 0, 1, 1)
        self.tabWidget.addTab(self.m_ForthW, "")
        # Tab 5 (step 5): model-file path, "choose file" and "convert" buttons.
        self.tab = QtWidgets.QWidget()
        self.tab.setObjectName("tab")
        self.gridLayout_9 = QtWidgets.QGridLayout(self.tab)
        self.gridLayout_9.setObjectName("gridLayout_9")
        self.m_choosemodelEd = QtWidgets.QLineEdit(self.tab)
        self.m_choosemodelEd.setObjectName("m_choosemodelEd")
        self.gridLayout_9.addWidget(self.m_choosemodelEd, 0, 0, 1, 1)
        self.m_choosemodelBtn = QtWidgets.QPushButton(self.tab)
        self.m_choosemodelBtn.setObjectName("m_choosemodelBtn")
        self.gridLayout_9.addWidget(self.m_choosemodelBtn, 1, 0, 1, 1)
        self.m_starttransformBtn = QtWidgets.QPushButton(self.tab)
        self.m_starttransformBtn.setObjectName("m_starttransformBtn")
        self.gridLayout_9.addWidget(self.m_starttransformBtn, 2, 0, 1, 1)
        self.tabWidget.addTab(self.tab, "")
        self.horizontalLayout.addWidget(self.tabWidget)
        MainWindow.setCentralWidget(self.centralwidget)
        # --- Menu bar ("open" menu), status bar, and a toolbar docked at the bottom ---
        self.menubar = QtWidgets.QMenuBar(MainWindow)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 1000, 22))
        self.menubar.setObjectName("menubar")
        self.openmenu = QtWidgets.QMenu(self.menubar)
        self.openmenu.setObjectName("openmenu")
        MainWindow.setMenuBar(self.menubar)
        self.statusbar = QtWidgets.QStatusBar(MainWindow)
        self.statusbar.setObjectName("statusbar")
        MainWindow.setStatusBar(self.statusbar)
        self.toolBar = QtWidgets.QToolBar(MainWindow)
        self.toolBar.setObjectName("toolBar")
        MainWindow.addToolBar(QtCore.Qt.BottomToolBarArea, self.toolBar)
        self.menubar.addAction(self.openmenu.menuAction())
        self.retranslateUi(MainWindow)
        # Designer-saved initial pages: the last toolbox page (index 4) and the
        # last tab (index 5). MainCode.initUi switches the tab back to 0.
        self.toolBox.setCurrentIndex(4)
        self.tabWidget.setCurrentIndex(5)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)
    def retranslateUi(self, MainWindow):
        """Set every user-visible text: window title, button/label captions, page and tab titles."""
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "训练界面"))
        self.m_ImageDirBtn.setText(_translate("MainWindow", "选择图像路径"))
        self.toolBox.setItemText(self.toolBox.indexOf(self.m_First), _translate("MainWindow", "第一步"))
        self.m_SetTrainPareBtn.setText(_translate("MainWindow", "设置训练参数"))
        self.toolBox.setItemText(self.toolBox.indexOf(self.m_Second), _translate("MainWindow", "第二步"))
        self.m_StartTrainBtn.setText(_translate("MainWindow", "开始训练"))
        self.toolBox.setItemText(self.toolBox.indexOf(self.m_Third), _translate("MainWindow", "第三步"))
        self.m_DetectSinglePicBtn.setText(_translate("MainWindow", "检测图像"))
        self.toolBox.setItemText(self.toolBox.indexOf(self.m_Forth), _translate("MainWindow", "第四步"))
        self.m_convertModel.setText(_translate("MainWindow", "模型转换"))
        self.toolBox.setItemText(self.toolBox.indexOf(self.m_fifth), _translate("MainWindow", "第五步"))
        self.m_homePagelabel.setText(_translate("MainWindow", "TextLabel"))
        self.tabWidget.setTabText(self.tabWidget.indexOf(self.home_page), _translate("MainWindow", "home"))
        self.la_image.setText(_translate("MainWindow", "TextLabel"))
        self.tabWidget.setTabText(self.tabWidget.indexOf(self.m_FirstW), _translate("MainWindow", "第一步"))
        self.m_MaxIterationL.setText(_translate("MainWindow", "最大训练次数:"))
        self.m_BathSizeL.setText(_translate("MainWindow", "batch_size(批尺寸)"))
        self.m_ImageSizeL.setText(_translate("MainWindow", "图像尺寸"))
        self.m_ValidationRatioL.setText(_translate("MainWindow", "验证集比例"))
        self.m_LearningRateL.setText(_translate("MainWindow", "学习率"))
        self.m_WeightDecayL.setText(_translate("MainWindow", "权重衰减系数:"))
        self.m_isCuda.setText(_translate("MainWindow", "是否使用显卡训练"))
        self.m_OkBtn.setText(_translate("MainWindow", "OK"))
        self.tabWidget.setTabText(self.tabWidget.indexOf(self.m_SecondW), _translate("MainWindow", "第二步"))
        self.m_initModelBtn.setText(_translate("MainWindow", "初始化"))
        self.m_startTrainBtn.setText(_translate("MainWindow", "开始训练"))
        self.m_pauseTrainBtn.setText(_translate("MainWindow", "暂停训练"))
        self.m_resumTrainBtn.setText(_translate("MainWindow", "继续训练"))
        self.m_stopTrainBtn.setText(_translate("MainWindow", "停止训练"))
        self.m_modelSaveBtn.setText(_translate("MainWindow", "选择路径"))
        self.m_modelSaveL.setText(_translate("MainWindow", "模型保存位置:"))
        self.tabWidget.setTabText(self.tabWidget.indexOf(self.m_ThirdW), _translate("MainWindow", "第三步"))
        self.m_loadmodelBtn.setText(_translate("MainWindow", "加载模型:"))
        self.la_result.setText(_translate("MainWindow", "TextLabel"))
        self.tabWidget.setTabText(self.tabWidget.indexOf(self.m_ForthW), _translate("MainWindow", "第四步"))
        self.m_choosemodelBtn.setText(_translate("MainWindow", "选择模型文件"))
        self.m_starttransformBtn.setText(_translate("MainWindow", "转换"))
        self.tabWidget.setTabText(self.tabWidget.indexOf(self.tab), _translate("MainWindow", "第五步"))
        self.openmenu.setTitle(_translate("MainWindow", "打开"))
        self.toolBar.setWindowTitle(_translate("MainWindow", "toolBar"))
训练参数类设置
class trainParameter:
    """Plain container for the training hyper-parameters.

    The numeric fields start at 0 and are filled in from the parameter form;
    ``dw`` is the weight-decay coefficient. The four ``*_weights`` fields
    default to 1.0 (presumably per-loss-term weights; not used in the code
    shown here).
    """

    def __init__(self):
        # Optimisation settings (zero until the user confirms the form).
        self.epochs = self.batch_size = self.image_size = 0
        self.validation_ratio = self.lr = self.dw = 0
        # Train on the GPU?
        self.Cuda = False
        # Loss weights, all neutral by default.
        self.kl_weights = self.l2_weights = self.gms_weights = self.ssim_weights = 1.0
多线程设置
class Thread(QThread):
    """Worker thread that runs the training loop with pause / resume / stop.

    Signals:
        valueChange(int): the 1-based epoch number, emitted before each epoch.
        textChange(int, str): the epoch and its average loss, emitted after it.

    ``mutex`` guards the ``_isPause`` / ``_isStop`` flags and is held only while
    they are checked or changed -- never during an epoch -- so pause(),
    resume() and stop() (called from the GUI thread) return immediately.
    """
    valueChange = pyqtSignal(int)
    textChange = pyqtSignal(int, str)

    def __init__(self, open_dir, train_parameters, save_dir):
        super(Thread, self).__init__()
        self._isPause = False
        self._isStop = False
        self._value = 0
        # A paused worker sleeps on ``cond`` until resume() or stop() wakes it.
        self.cond = QWaitCondition()
        self.mutex = QMutex()
        self.train = train_parameters
        self.open_dir = open_dir
        self.save_dir = save_dir
        self.xnet = XNet(self.open_dir, self.train, self.save_dir)

    def pause(self):
        """Request a pause; it takes effect before the next epoch starts."""
        self.mutex.lock()
        self._isPause = True
        self.mutex.unlock()

    def stop(self):
        """Request a stop; also wakes the worker if it is currently paused."""
        self.mutex.lock()
        self._isStop = True
        # Without this a paused worker would wait forever.
        self.cond.wakeAll()
        self.mutex.unlock()

    def resume(self):
        """Clear a pause request and wake the waiting worker."""
        self.mutex.lock()
        self._isPause = False
        self.cond.wakeAll()
        self.mutex.unlock()

    def _wait_while_paused(self):
        """Block while paused; return True if a stop has been requested."""
        self.mutex.lock()
        try:
            # Loop guards against spurious wakeups; the flags are only read
            # under the mutex, so a resume() cannot slip in before wait().
            while self._isPause and not self._isStop:
                self.cond.wait(self.mutex)
            return self._isStop
        finally:
            # Always release, including on the stop path.
            self.mutex.unlock()

    def run(self):
        """Train for ``train.epochs`` epochs, saving weights every 50 epochs.

        GUI updates go through the queued signals, so no event processing is
        needed in this thread.
        """
        for epoch in range(1, self.train.epochs + 1):
            self.valueChange.emit(epoch)
            if self._wait_while_paused():
                return
            loss_avg = self.xnet.train_one_epoch()
            if np.isnan(loss_avg):
                # Report a large sentinel instead of NaN.
                loss_avg = 1e6
            print('Train Epoch: {} loss: {:.6f}'.format(epoch, loss_avg))
            self.textChange.emit(epoch, str(loss_avg))
            if epoch % 50 == 0:
                self.xnet.save_model(epoch)  # checkpoint the weights
主界面功能
类
class MainCode(QMainWindow, Ui_MainWindow):
    """Main window: connects the generated UI to the training workflow."""

    def __init__(self):
        """Build the UI while reporting progress on the splash screen.

        Expects a module-level ``splash`` (a GifSplashScreen) to exist already.
        """
        QMainWindow.__init__(self)
        Ui_MainWindow.__init__(self)
        self.setupUi(self)
        where = Qt.AlignHCenter | Qt.AlignBottom
        # Two one-second "loading" steps; processEvents() lets the splash repaint.
        for step in range(2):
            sleep(1)
            splash.showMessage('加载进度: %d' % step, where, Qt.white)
            QApplication.instance().processEvents()
        splash.showMessage('初始化完成', where, Qt.white)
        # The splash closes itself once this window is shown.
        splash.finish(self)
        self.cwd = os.getcwd()  # start folder for the file dialogs
        self.train = trainParameter()
        self.initUi()
背景图片展示
def showHomepage(self):
    """Show the home-page background image, stretched to fill its label."""
    imgName = "./Lib/train.jpeg"
    # Give the label the full-resolution pixmap and let setScaledContents()
    # stretch it. Pre-scaling to the label's current size would bake in the
    # size the label has when this runs; it runs from __init__ (via initUi),
    # before the window is laid out, so the image came out blurry once the
    # label grew.
    self.m_homePagelabel.setPixmap(QPixmap(imgName))
    self.m_homePagelabel.setScaledContents(True)
右下角(状态栏)显示当前系统时间
# Refresh the clock label with the current date and time.
def showCurrentTime(self, timeLabel):
    """Write the current local date/time into *timeLabel*.

    Format: ``yyyy-MM-dd hh:mm:ss`` followed by the full weekday name.
    """
    stamp = QDateTime.currentDateTime().toString('yyyy-MM-dd hh:mm:ss dddd')
    timeLabel.setText(stamp)
# Show a live clock in the status bar.
def statusShowTime(self):
    """Add a permanent clock label to the status bar and refresh it every second."""
    self.timer = QTimer()
    self.timeLabel = QLabel()
    self.statusbar.addPermanentWidget(self.timeLabel, 0)
    # Each timeout re-renders the time through showCurrentTime().
    self.timer.timeout.connect(lambda: self.showCurrentTime(self.timeLabel))
    # Fill the label right away; otherwise it stays blank until the first
    # timeout one second later.
    self.showCurrentTime(self.timeLabel)
    self.timer.start(1000)  # refresh interval in ms (1 s)
选择图像路径并显示在列表中
def openfiledir(self):
    """Let the user pick an image folder and list its images in the table.

    Lists ``*.bmp``, ``*.jpg`` and ``*.png`` files (case-insensitive, sorted by
    name), previews the first one and switches to tab 1. Cancelling the
    dialog leaves the current folder and list untouched.
    """
    openfile_dir = QFileDialog.getExistingDirectory(self.centralwidget, "选择路径", self.cwd)
    # An empty string means the dialog was cancelled. QDir("") would resolve
    # to the current working directory, so stop before touching the UI.
    if not openfile_dir:
        return
    self.le_imageDir.setText(openfile_dir)
    imgdir = QDir(openfile_dir)
    if not imgdir.exists():
        return
    imgList = imgdir.entryList(['*.bmp', '*.jpg', '*.png'],
                               QtCore.QDir.NoFilter,
                               QtCore.QDir.Name | QtCore.QDir.IgnoreCase)
    cnt = len(imgList)
    self.tableImgList.clear()
    self.tableImgList.setRowCount(cnt)
    self.tableImgList.setColumnCount(2)
    for row, name in enumerate(imgList):
        # Both columns hold the file name: drawImage() builds the path from
        # whichever cell is current, so either column must contain it.
        self.tableImgList.setItem(row, 0, QTableWidgetItem(name))
        self.tableImgList.setItem(row, 1, QTableWidgetItem(name))
    self.tableImgList.setEditTriggers(QAbstractItemView.NoEditTriggers)  # read-only
    # Preview the first image, if any.
    if cnt > 0:
        imgName = openfile_dir + "/" + imgList[0]
        jpg = QPixmap(imgName).scaled(self.la_image.width(), self.la_image.height())
        self.la_image.setPixmap(jpg)
    self.tabWidget.setCurrentIndex(1)
点击列表切换图像
def drawImage(self):  # preview the image clicked in the list
    """Preview the image named by the currently selected table cell.

    Does nothing when no folder has been chosen or no cell is current.
    """
    open_dir = self.le_imageDir.text()
    if len(open_dir) == 0:
        return
    item = self.tableImgList.currentItem()
    # itemSelectionChanged can fire with no current item (e.g. while
    # openfiledir() clears and refills the table); dereferencing None would
    # raise inside the slot.
    if item is None:
        return
    imgFile = open_dir + "/" + item.text()
    jpg = QPixmap(imgFile).scaled(self.la_image.width(), self.la_image.height())
    self.la_image.setPixmap(jpg)
读取配置文件
支持使用默认配置及手动输入
# Load the default training configuration.
def readConfig(self):
    """Pre-fill the parameter form from ./config.yaml and ensure the save dir exists.

    Reads the keys epochs, batch_size, image_size, validation_ratio, lr, dw,
    CUDA and save_dir; a missing key raises KeyError.
    """
    import yaml
    config = "./config.yaml"
    with open(config, 'r', encoding='utf8') as file:
        # safe_load accepts the open stream directly.
        d = yaml.safe_load(file)
    form = (
        (self.m_MaxIterationEd, 'epochs'),
        (self.m_BathSizeEd, 'batch_size'),
        (self.m_ImageSizeEd, 'image_size'),
        (self.m_ValidationRatioEd, 'validation_ratio'),
        (self.m_LearningRateEd, 'lr'),
        (self.m_WeightDecayEd, 'dw'),
    )
    for editor, key in form:
        editor.setText(str(d[key]))
    self.m_isCuda.setChecked(bool(d['CUDA']))
    save_dir = d['save_dir']
    self.m_modelSaveEd.setText(save_dir)
    # makedirs also creates missing parent folders (os.mkdir would raise for a
    # nested path) and tolerates a folder that already exists.
    os.makedirs(save_dir, exist_ok=True)
手动修改训练超参数
def finishSetting(self):
    """Validate the parameter form and copy it into ``self.train``.

    Shows a warning and returns without modifying ``self.train`` when a field
    is empty or not a valid number. On success shows a confirmation and
    shakes the window once it is dismissed with OK.
    """
    # (editor, name used in warnings, converter), in form order.
    fields = (
        (self.m_MaxIterationEd, '训练次数', int),
        (self.m_BathSizeEd, '批尺寸', int),
        (self.m_ImageSizeEd, '图像尺寸', int),
        (self.m_ValidationRatioEd, '验证集比例', float),
        (self.m_LearningRateEd, '学习率', float),
        (self.m_WeightDecayEd, '权重衰减系数', float),
    )
    # No field may be left empty; these checks come first, in form order.
    for editor, label, _ in fields:
        if len(editor.text()) == 0:
            QMessageBox.information(self, '警告', label + '不能为空')
            return
    # Convert everything before assigning, so a bad value cannot leave
    # self.train half-updated. A ValueError escaping this slot would abort
    # the application (PyQt5 >= 5.5 treats unhandled slot exceptions as fatal).
    values = []
    for editor, label, convert in fields:
        try:
            values.append(convert(editor.text()))
        except ValueError:
            QMessageBox.information(self, '警告', label + '格式不正确')
            return
    (self.train.epochs, self.train.batch_size, self.train.image_size,
     self.train.validation_ratio, self.train.lr, self.train.dw) = values
    self.train.Cuda = self.m_isCuda.isChecked()
    reply = QMessageBox.information(self, "通知", "参数配置完成", QMessageBox.Ok)
    if reply == QMessageBox.Ok:
        self.doShake()
选择模型保存路径
def selectSaveDir(self):
    """Ask for the model output folder; keep the current one if the dialog is cancelled."""
    save_dir = QFileDialog.getExistingDirectory(self.centralwidget, "选择路径", self.cwd)
    # An empty string means the dialog was cancelled; it is not a usable
    # output folder, so do not overwrite the field with it.
    if save_dir:
        self.m_modelSaveEd.setText(save_dir)
开始训练
初始化模型
def trainer(self):
    """Prepare a training run: create the worker thread and open the log file.

    Sets the progress-bar maximum to the epoch count, wires the thread's
    signals, then disables the init button and enables the start button.
    Returns without doing anything else if the image folder or the save
    folder is empty.
    """
    self.m_modelTrainProcesssbar.setMaximum(self.train.epochs)
    open_dir = self.le_imageDir.text()
    save_dir = self.m_modelSaveEd.text()
    # An empty save folder would turn the log path into "/log.txt".
    if len(open_dir) == 0 or len(save_dir) == 0:
        return
    # Build the worker first: if its construction raises, no log handle leaks.
    self.t = Thread(open_dir, self.train, save_dir)
    self.logFile = open(os.path.join(save_dir, 'log.txt'), 'a', encoding='utf-8')
    self.t.valueChange.connect(self.m_modelTrainProcesssbar.setValue)
    self.t.textChange.connect(self.updateText)
    # The thread is started later by startTrainer().
    self.m_initModelBtn.setEnabled(False)
    self.m_startTrainBtn.setEnabled(True)
开始训练
def startTrainer(self):
    """Launch the prepared training thread; disable Start and enable Pause."""
    worker = self.t
    worker.start()
    for button, enabled in ((self.m_startTrainBtn, False), (self.m_pauseTrainBtn, True)):
        button.setEnabled(enabled)
暂停训练
def suspendTrainer(self):
    """Ask the training thread to pause; disable Pause and enable Resume."""
    worker = self.t
    worker.pause()
    for button, enabled in ((self.m_pauseTrainBtn, False), (self.m_resumTrainBtn, True)):
        button.setEnabled(enabled)
继续训练
def wakeTrainer(self):
    """Resume the paused training thread; enable Pause and disable Resume."""
    worker = self.t
    worker.resume()
    for button, enabled in ((self.m_pauseTrainBtn, True), (self.m_resumTrainBtn, False)):
        button.setEnabled(enabled)
停止训练
def stopTrainer(self):
    """After confirmation, stop training, reset the buttons and release the log.

    Does nothing if no training thread has been created yet.
    """
    worker = getattr(self, 't', None)
    if worker is None:
        # The stop button is clickable before trainer() has created a thread.
        return
    reply = QMessageBox.question(self, "停止训练", "您确定要停止训练吗?", QMessageBox.Yes | QMessageBox.No, QMessageBox.No)
    if reply != QMessageBox.Yes:
        return
    worker.stop()
    # Only "init" stays usable; a new run must be set up again.
    self.m_initModelBtn.setEnabled(True)
    for button in (self.m_startTrainBtn, self.m_pauseTrainBtn, self.m_resumTrainBtn):
        button.setEnabled(False)
    log_file = getattr(self, 'logFile', None)
    if log_file is None:
        return
    # The worker may still be finishing its current epoch and will emit one
    # more textChange, which updateText() writes to this file. Close it when
    # the thread finishes; signals it emitted earlier are delivered first.
    worker.finished.connect(log_file.close)
    if not worker.isRunning():
        # Never started or already done: close now (close() is idempotent).
        log_file.close()
实时显示训练loss
# Multi-line log view: append each epoch's loss to the log file and the text browser.
def updateText(self, val, text):
    """Record one training update (slot for Thread.textChange).

    :param val: epoch number.
    :param text: average loss for that epoch, already converted to str.
    """
    msg = "epoch:{0}".format(val) + " " + "loss:" + text
    log_file = getattr(self, 'logFile', None)
    # A queued update can arrive after the log was closed on stop; writing
    # would raise ValueError inside the slot. Still show it on screen.
    if log_file is not None and not log_file.closed:
        log_file.write('\n{}'.format(msg))
    self.textBrowser.append(msg)
    # Keep the newest line visible.
    self.textBrowser.moveCursor(self.textBrowser.textCursor().End)
点击按钮实现窗口抖动
# Shake the window once after OK is clicked.
def doShake(self):
    """Play the shake animation on the main window itself."""
    target = self
    self.doShakeWindow(target)
# This helper can shake any widget, not just the main window.
def doShakeWindow(self, target):
    """Play a short shake animation on *target*'s ``pos`` property.

    Ignored while a previous shake on the same widget is still tracked in
    ``target._shake_animation``.

    :param target: widget to shake
    """
    if hasattr(target, '_shake_animation'):
        return
    anim = QPropertyAnimation(target, b'pos', target)
    target._shake_animation = anim
    # Forget the animation when it ends so the next shake can start.
    anim.finished.connect(lambda: delattr(target, '_shake_animation'))
    origin = target.pos()
    x0, y0 = origin.x(), origin.y()
    anim.setDuration(200)
    anim.setLoopCount(2)
    # (progress, dx, dy) keyframes relative to the starting position.
    keyframes = (
        (0, 0, 0), (0.09, 2, -2), (0.18, 4, -4), (0.27, 2, -6),
        (0.36, 0, -8), (0.45, -2, -10), (0.54, -4, -8), (0.63, -6, -6),
        (0.72, -8, -4), (0.81, -6, -2), (0.90, -4, 0), (0.99, -2, 2),
    )
    for step, dx, dy in keyframes:
        anim.setKeyValueAt(step, QPoint(x0 + dx, y0 + dy))
    anim.setEndValue(QPoint(x0, y0))
    anim.start(anim.DeleteWhenStopped)
信号和槽
def initUi(self):
    """Finish window setup: home page, button effect, signal wiring and defaults."""
    self.showHomepage()
    self.tabWidget.setCurrentIndex(0)
    # Make the first-step button flash (AnimationShadowEffect is defined elsewhere).
    glow = AnimationShadowEffect(Qt.blue, self.m_ImageDirBtn)
    self.m_ImageDirBtn.setGraphicsEffect(glow)
    glow.start()
    # Signal -> slot wiring.
    wiring = (
        (self.m_ImageDirBtn.clicked, self.openfiledir),
        (self.tableImgList.itemSelectionChanged, self.drawImage),
        (self.m_OkBtn.clicked, self.finishSetting),
        (self.m_modelSaveBtn.clicked, self.selectSaveDir),
        (self.m_initModelBtn.clicked, self.trainer),
        (self.m_startTrainBtn.clicked, self.startTrainer),
        (self.m_pauseTrainBtn.clicked, self.suspendTrainer),
        (self.m_resumTrainBtn.clicked, self.wakeTrainer),
        (self.m_stopTrainBtn.clicked, self.stopTrainer),
        (self.m_SetTrainPareBtn.clicked, self.second),
        (self.m_StartTrainBtn.clicked, self.third),
    )
    for signal, slot in wiring:
        signal.connect(slot)
    self.m_modelTrainProcesssbar.setValue(0)
    self.readConfig()
    self.statusShowTime()
    self.m_loadmodelBtn.clicked.connect(self.loadmodelPath)
    # keyboard.add_hotkey('alt+s', self.onShow, suppress=False)  # show window
    # keyboard.add_hotkey('ctrl+s', self.onHide, suppress=False)  # hide window
    # Start / pause / resume stay disabled until a model is initialised.
    for button in (self.m_startTrainBtn, self.m_pauseTrainBtn, self.m_resumTrainBtn):
        button.setEnabled(False)
退出软件
def closeEvent(self, event):
"""
对MainWindow的函数closeEvent进行重构
退出软件时结束所有进程
"""
reply = QMessageBox.question(self,
'本程序',
"是否要退出程序?",
QMessageBox.Yes | QMessageBox.No,
QMessageBox.No)
if reply == QMessageBox.Yes:
event.accept()
os._exit(0)
else: