gevent协程、select IO多路复用、socketserver模块 改造多用户FTP程序例子

时间:2022-05-22 20:47:06

原多线程版FTP程序:http://www.cnblogs.com/linzetong/p/8290378.html

只需要在原来的代码基础上稍作修改:

一、gevent协程版本

1、 导入gevent模块

import gevent

2、给python的异步库gevent打猴子补丁(monkey patch)。它的用途是把标准库中会阻塞的模块(如socket、time)替换成gevent的非阻塞实现,代码里无需专门去导入gevent版本的模块。

#注意:from gevent import monkey;monkey.patch_all() 必须放在被打补丁的模块(如time、socket)导入之前

from gevent import monkey;monkey.patch_all()

3、 把socket设置为非阻塞

self.sock.setblocking(False)  # put the listening socket into non-blocking mode

4、修改run函数:

# gevent 实现单线程多并发

gevent.spawn(TCPHandler().handle, self.request, self.cli_addr)  # serve this connection in its own greenlet

其他不用更改

二、select IO多路复用版本

1、 导入select模块

import select

2、 把socket设置为非阻塞

self.sock.setblocking(False)  # the select() loop must not block inside accept()

3、 修改run函数,用select.select()方法接收并监控多个通信socket列表

def run(self):
    """Serve all clients from one process with select.select().

    The listening socket and every accepted client socket are watched for
    readability. Each readable client socket is handed to a new TCPHandler,
    which serves one request and returns ``(keep_open, request)``.
    """
    inputs = [self.sock, ]
    outputs = []
    # A single loop is enough: the former outer ``while True`` could never repeat.
    while True:
        readable, writeable, exceptional = select.select(inputs, outputs, inputs)
        for r in readable:
            if r is self.sock:
                request, client_address = self.sock.accept()
                inputs.append(request)
            else:
                print('处理request:%s' % id(r))
                return_code, request = TCPHandler().handle(r, )
                if not return_code:
                    # Read the peer address *before* closing: getpeername() on a
                    # closed socket raises OSError (the old code did it after close()).
                    try:
                        peer = request.getpeername()
                    except OSError:  # peer already gone
                        peer = None
                    inputs.remove(request)
                    request.close()
                    print('client[%s] is disconect' % (peer,))

4、完整代码:

server.py

 # -*- coding: utf-8 -*-
from gevent import monkey

# Patch first, before socket/select/threading/time are imported, as the note at
# the start of this article requires (it used to run after those imports).
monkey.patch_all()

import socket
import os
import json
import re
import struct
import threading
import time
import gevent
import select
from lib import commons
from conf import settings
from core import logger


class Server(object):
    """FTP server that serves every client from one thread using select.select().

    The constructor creates the data directories, binds and listens on
    ``(settings.server_bind_ip, settings.server_bind_port)`` and then enters
    ``run()``, which never returns.
    """

    def __init__(self):
        self.init_dir()
        self.sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
        self.sock.setblocking(0)  # the select() loop must not block inside accept()
        # self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.sock.bind((settings.server_bind_ip, settings.server_bind_port))
        self.sock.listen(settings.server_listen)
        print("\033[42;1mserver started sucessful!\033[0m")
        self.run()

    @staticmethod
    def init_dir():
        """Create the logs, db and home directories under settings.base_path if missing."""
        for name in ('logs', 'db', 'home'):
            path = os.path.join(settings.base_path, name)
            if not os.path.exists(path):
                os.mkdir(path)

    def run(self):
        """Accept clients and serve one request per readable client socket, forever.

        Each readable client socket is handed to a fresh TCPHandler. When the
        handler reports the connection is finished, the socket is removed from
        the watch list and closed.
        """
        inputs = [self.sock]
        outputs = []
        while True:
            # ``exceptional`` is requested but not acted upon.
            readable, writeable, exceptional = select.select(inputs, outputs, inputs)
            for r in readable:
                if r is self.sock:
                    try:
                        request, client_address = self.sock.accept()
                    except BlockingIOError:
                        # The pending connection vanished between select() and accept().
                        continue
                    inputs.append(request)
                    continue
                print('处理request:%s' % id(r))
                return_code, request = TCPHandler().handle(r)
                if not return_code:
                    self._close_client(inputs, request)
        # Alternatives shown in this article:
        #   threads: threading.Thread(target=TCPHandler.handle, args=(TCPHandler(), request, addr)).start()
        #   gevent:  gevent.spawn(TCPHandler.handle, TCPHandler(), request, addr)

    @staticmethod
    def _close_client(inputs, request):
        """Stop watching ``request`` and close it.

        The peer address is read before closing. The old code called
        getpeername() on the closed socket (OSError) and left it in ``inputs``,
        which made the next select() call fail.
        """
        try:
            peer = request.getpeername()
        except OSError:  # peer already gone
            peer = None
        if request in inputs:
            inputs.remove(request)
        request.close()
        print('client[%s] is disconect' % (peer,))


class TCPHandler(object):
    """Serve FTP requests that arrive on one client socket.

    Request wire format: a 4-byte ``struct`` 'i' length prefix, a UTF-8 JSON
    header (``cmd``, ``info_size`` and optionally ``md5``), then ``info_size``
    bytes of UTF-8 JSON body. Responses are framed by ``commons.make_header``
    (presumably the same layout -- confirm in lib/commons).
    """
    STATUS_CODE = {
        200: 'Passed authentication!',
        201: 'Wrong username or password!',
        202: 'Username does not exist!',
        300: 'cmd successful , the target path be returned in returnPath',
        301: 'cmd format error!',
        302: 'The path or file could not be found!',
        303: 'The dir is exist',
        304: 'The file has been downloaded or the size of the file is exceptions',
        305: 'Free space is not enough',
        401: 'File MD5 inspection failed',
        400: 'File MD5 inspection success',
    }
    # Commands a client may invoke. Dispatch used to be ``hasattr(self, cmd)``,
    # which also exposed helpers such as ``auth`` or ``handle`` to clients.
    COMMANDS = ('register', 'login', 'query_quota', 'cd', 'ls', 'rm', 'mkdir', 'get', 'put')
    # Upper bound for a single recv() call.
    RECV_SIZE = 1024
    # Block size used when streaming a file to the client.
    SEND_SIZE = 8192

    def __init__(self):
        self.server_logger = logger.logger('server')
        self.server_logger.debug("server TCPHandler started successful!")

    def handle(self, request, address=(None, None)):
        """Serve exactly one request from ``request`` (called by the select loop).

        Args:
            request: connected client socket that select() reported readable.
            address: fallback client address, used only if getpeername() fails.

        Returns:
            ``(keep_open, request)``. ``keep_open`` is False when the caller
            should close the connection: the client left, timed out or sent a
            malformed request.
        """
        self.request = request
        self.cli_addr = address
        try:
            self.cli_addr = request.getpeername()
            self.server_logger.info('client[%s] is conecting' % (self.cli_addr,))
            print('client[%s] is conecting' % (self.cli_addr,))
            # 1. Receive the client's ftp command.
            print("waiting receive client[%s] ftp command.." % (self.cli_addr,), id(self), self)
            header_dic, req_dic = self.recv_request()
            if not header_dic or not header_dic.get('cmd'):
                return False, request
            print('receive client ftp command:%s' % header_dic['cmd'])
            # 2. Split the command into name and arguments, e.g. ['get', 'a.txt'].
            cmds = header_dic['cmd'].split()
            if cmds and cmds[0] in self.COMMANDS:
                self.server_logger.info('interface:[%s], request:{client:[%s:%s] action:[%s]}' % (
                    cmds[0], self.cli_addr[0], self.cli_addr[1], header_dic['cmd']))
                getattr(self, cmds[0])(header_dic, req_dic)
            else:
                self.server_logger.warning('client[%s:%s] unknown command: %s' % (
                    self.cli_addr[0], self.cli_addr[1], header_dic['cmd']))
            return True, request
        except (ConnectionResetError, ConnectionAbortedError):
            return False, request
        except socket.timeout:
            print('time out %s' % (self.cli_addr,))
            return False, request
        except (OSError, ValueError, KeyError, TypeError, AttributeError, IndexError) as e:
            # One misbehaving client must not take down the single-threaded loop.
            self.server_logger.warning('client[%s:%s] request failed: %r' % (
                self.cli_addr[0], self.cli_addr[1], e))
            return False, request

    def _recv_exact(self, size):
        """Return exactly ``size`` bytes read from the client socket.

        recv() on a stream socket may return fewer bytes than requested, so this
        loops until everything has arrived.

        Raises:
            ConnectionResetError: the client closed the connection first.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = self.request.recv(min(remaining, self.RECV_SIZE))
            if not chunk:
                raise ConnectionResetError('connection closed by client')
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def unpack_header(self):
        """Read the 4-byte length prefix and the JSON header it announces.

        Returns:
            The decoded header, or None if the announced size is not positive.

        Raises:
            ConnectionResetError: the client disconnected (also on a clean close).
            ValueError: the header is not valid UTF-8 JSON.
        """
        header_size = struct.unpack('i', self._recv_exact(4))[0]
        if header_size <= 0:  # guards against a corrupt/garbage length prefix
            return None
        header_json = self._recv_exact(header_size).decode('utf-8')
        return json.loads(header_json)

    def unpack_info(self, info_size):
        """Read ``info_size`` bytes of JSON request body.

        Returns:
            ``(info_dic, info_md5)``. An empty body yields ``{}``, because
            ``json.loads('')`` would raise.
        """
        info_bytes = self._recv_exact(info_size)
        info_dic = json.loads(info_bytes.decode('utf-8')) if info_bytes else {}
        info_md5 = commons.getStrsMd5(info_bytes)
        return info_dic, info_md5

    def recv_request(self):
        """Read one request (header plus body).

        Returns:
            ``(header_dic, req_dic)``, or ``(None, None)`` when no usable header arrived.
        """
        header_dic = self.unpack_header()  # e.g. {'cmd': 'register', 'info_size': 0}
        if not header_dic:
            return None, None
        req_dic, info_md5 = self.unpack_info(header_dic['info_size'])
        # The body checksum is optional. A mismatch is logged but the request is
        # still processed, as before.
        if header_dic.get('md5') and info_md5 != header_dic['md5']:
            self.server_logger.warning('request md5 mismatch, cmd: %s' % header_dic.get('cmd'))
        return header_dic, req_dic

    def response(self, **kwargs):
        """Send ``kwargs`` to the client as a framed JSON response."""
        rsp_bytes = commons.getDictBytes(kwargs)
        md5 = commons.getStrsMd5(rsp_bytes)
        header_size_pack, header_bytes = commons.make_header(info_size=len(rsp_bytes), md5=md5)
        self.request.sendall(header_size_pack)
        self.request.sendall(header_bytes)
        self.request.sendall(rsp_bytes)

    @staticmethod
    def _safe_username(name):
        """Return True if ``name`` can safely be used as a file/directory name.

        Usernames are joined into db and home paths, so separators, drive
        colons, NUL and '.'/'..' would allow path traversal.
        """
        return (isinstance(name, str) and name not in ('', '.', '..')
                and not any(c in name for c in '/\\:\0'))

    @staticmethod
    def _is_within(path, home, strict=False):
        """Return True if ``path`` is ``home`` or lies below it.

        If ``strict`` is true, ``home`` itself does not count. This replaces the
        substring test ``home in path``. That test accepted ``home\\..\\..``
        escapes and sibling homes sharing a prefix, such as ``ton`` and ``tony``.
        """
        path = os.path.normcase(os.path.normpath(path))
        home = os.path.normcase(os.path.normpath(home))
        if path == home:
            return not strict
        return path.startswith(home.rstrip(os.sep) + os.sep)

    def register(self, header_dic, req_dic):
        """Create a user account and home directory, then reply with resultCode/resultDesc.

        resultCode 0 means created. 1 means rejected: the user already exists or
        the username is invalid.
        """
        username = req_dic['user_info']['username']
        db_path = os.path.join(settings.db_file, '%s.json' % username)
        if self._safe_username(username) and not os.path.isfile(db_path):
            # Persist the user record.
            user_info = dict()
            user_info['username'] = username
            user_info['password'] = req_dic['user_info']['password']
            user_info['home'] = os.path.join(settings.user_home_dir, username)
            user_info['quota'] = settings.user_quota * (1024 * 1024)
            commons.save_to_file(user_info, db_path)
            resultCode = 0
            resultDesc = None
            # Create the home directory.
            if not os.path.exists(user_info['home']):
                os.mkdir(user_info['home'])
            self.server_logger.info('client[%s:%s] 注册用户[%s]成功' % (self.cli_addr[0], self.cli_addr[1], username))
        else:
            resultCode = 1
            if self._safe_username(username):
                resultDesc = '该用户已存在,注册失败'
            else:
                resultDesc = '用户名不合法,注册失败'
            self.server_logger.warning('client[%s:%s] 注册用户[%s]失败:%s' % (self.cli_addr[0], self.cli_addr[1],
                                                                          username, resultDesc))
        self.response(resultCode=resultCode, resultDesc=resultDesc)

    @staticmethod
    def auth(req_dic):
        """Check the username/password in ``req_dic['user_info']`` against the db.

        Returns:
            ``(status_code, user_info)``: ``(200, record)`` on success,
            ``(202, None)`` for an unknown or invalid username, and
            ``(201, None)`` for a wrong password or a malformed request.
        """
        try:
            req_username = req_dic['user_info']['username']
            if not TCPHandler._safe_username(req_username):
                return 202, None
            db_file = os.path.join(settings.db_file, '%s.json' % req_username)
            if not os.path.isfile(db_file):
                return 202, None
            with open(db_file, 'r') as f:
                user_info_db = json.load(f)
            if user_info_db['password'] == req_dic['user_info']['password']:
                return 200, user_info_db
            return 201, None
        # The client sent an empty or wrongly shaped auth dict.
        except (KeyError, TypeError):
            return 201, None

    def login(self, header_dic, req_dic):
        """Authenticate and reply with the user record and status code."""
        status_code, user_info = self.auth(req_dic)
        self.response(user_info=user_info, resultCode=status_code)

    def query_quota(self, header_dic, req_dic):
        """Reply with the used and total quota (both None unless authenticated)."""
        used_quota = None
        total_quota = None
        status_code, user_info = self.auth(req_dic)
        if status_code == 200:
            used_quota = commons.getFileSize(user_info['home'])
            total_quota = user_info['quota']
        self.response(resultCode=status_code, total_quota=total_quota, used_quota=used_quota)

    @staticmethod
    def parse_file_path(req_path, cur_path):
        """Resolve ``req_path`` against ``cur_path`` and return the path parts.

        Windows-style resolution:
        - '/' becomes '\\' and pairs of backslashes are collapsed.
        - The first '~\\', then the first '~', are removed.
        - '.' parts are dropped.
        - Each '..' removes itself together with the part before it.

        The first part gets a trailing '\\' so that ``os.path.join('', *parts)``
        rebuilds a rooted path (presumably a drive such as ``C:``).
        """
        req_path = req_path.replace(r'/', '\\')
        req_path = req_path.replace('\\\\', '\\')
        req_path = req_path.replace('~\\', '', 1)
        req_path = req_path.replace(r'~', '', 1)
        req_paths = re.findall(r'[^\\]+', req_path)
        cur_paths = re.findall(r'[^\\]+', cur_path)
        cur_paths.extend(req_paths)
        cur_paths[0] += '\\'
        while '.' in cur_paths:
            cur_paths.remove('.')
        while '..' in cur_paths:
            index = cur_paths.index('..')
            cur_paths.pop(index)
            cur_paths.pop(index - 1)
        return cur_paths

    def cd(self, header_dic, req_dic):
        """Change the client's current directory within its home; reply with returnPath."""
        cmds = header_dic['cmd'].split()
        status_code, user_info = self.auth(req_dic)
        returnPath = req_dic['user_info']['cur_path']
        if status_code != 200:
            resultCode = 201
        elif len(cmds) == 1:
            resultCode = 301
        else:
            home = os.path.join(settings.user_home_dir, user_info['username'])
            cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
            cd_path = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
            print('cd解析后的路径:', cd_path)
            if os.path.isdir(cd_path) and self._is_within(cd_path, home):
                resultCode = 300
                returnPath = cd_path.replace('%s\\' % settings.user_home_dir, '', 1)
            else:
                resultCode = 302
        print('cd发送给客户端的路径:', returnPath)
        self.response(resultCode=resultCode, returnPath=returnPath)

    def ls(self, header_dic, req_dic):
        """List a directory within the client's home; reply with returnFilenames."""
        cmds = header_dic['cmd'].split()
        status_code, user_info = self.auth(req_dic)
        returnFilenames = None
        if status_code != 200:
            resultCode = 201
        elif len(cmds) > 2:
            resultCode = 301
        else:
            home = os.path.join(settings.user_home_dir, user_info['username'])
            cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
            if len(cmds) == 2:
                ls_path = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
            else:
                ls_path = cur_path
            print('ls解析后的路径:', ls_path)
            if os.path.isdir(ls_path) and self._is_within(ls_path, home):
                _, returnFilenames = commons.getFile(ls_path, home)
                resultCode = 300
            else:
                resultCode = 302
        # The former time.sleep(5) here froze every client of this single-threaded loop.
        self.response(resultCode=resultCode, returnFilenames=returnFilenames)

    def rm(self, header_dic, req_dic):
        """Delete a file or directory strictly inside the client's home."""
        cmds = header_dic['cmd'].split()
        status_code, user_info = self.auth(req_dic)
        if status_code != 200:
            resultCode = 201
        elif len(cmds) != 2:
            resultCode = 301
        else:
            home = os.path.join(settings.user_home_dir, user_info['username'])
            cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
            rm_path = os.path.join('', *self.parse_file_path(os.path.dirname(cmds[1]), cur_path))
            rm_file = os.path.join(rm_path, os.path.basename(cmds[1]))
            print('rm解析后的文件或文件夹:', rm_file)
            # strict: never delete the home directory itself (e.g. "rm .").
            if os.path.exists(rm_file) and self._is_within(rm_file, home, strict=True):
                commons.rmdirs(rm_file)
                resultCode = 300
            else:
                resultCode = 302
        self.response(resultCode=resultCode)

    def mkdir(self, header_dic, req_dic):
        """Create a directory (with parents) inside the client's home."""
        cmds = header_dic['cmd'].split()
        status_code, user_info = self.auth(req_dic)
        if status_code != 200:
            resultCode = 201
        elif len(cmds) != 2:
            resultCode = 301
        else:
            home = os.path.join(settings.user_home_dir, user_info['username'])
            cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
            mkdir_path = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
            print('mkdir解析后的文件夹:', mkdir_path)
            if os.path.isdir(mkdir_path):
                resultCode = 303
            elif self._is_within(mkdir_path, home):
                os.makedirs(mkdir_path)
                resultCode = 300
            else:
                resultCode = 302
        self.response(resultCode=resultCode)

    def get(self, header_dic, req_dic):
        """Send a file to the client, resuming at ``req_dic['position']`` if requested."""
        cmds = header_dic['cmd'].split()  # ['get', 'a.txt', 'download']
        get_file = None
        FileSize = None
        FileMd5 = None
        status_code, user_info = self.auth(req_dic)
        # Resume offset: only honoured for a positive int when 'resume' is set.
        position = req_dic.get('position')
        if not (req_dic.get('resume') and isinstance(position, int) and position > 0):
            position = 0
        if status_code != 200:
            resultCode = 201
        elif not 1 < len(cmds) < 4:
            resultCode = 301
        else:
            home = os.path.join(settings.user_home_dir, user_info['username'])
            cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
            get_file = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
            print('get解析后的路径:', get_file)
            if not (os.path.isfile(get_file) and self._is_within(get_file, home)):
                resultCode = 302
            else:
                FileSize = commons.getFileSize(get_file)
                if position >= FileSize != 0:  # nothing left to send
                    resultCode = 304
                else:
                    resultCode = 300
                    FileMd5 = commons.getFileMd5(get_file)
        self.response(resultCode=resultCode, FileSize=FileSize, FileMd5=FileMd5)
        if resultCode == 300:
            with open(get_file, 'rb') as f:
                f.seek(position)
                # sendall(): send() may transmit only part of each block.
                for block in iter(lambda: f.read(self.SEND_SIZE), b''):
                    self.request.sendall(block)

    def put(self, header_dic, req_dic):
        """Receive an upload into the client's home, enforcing its quota and MD5."""
        cmds = header_dic['cmd'].split()  # ['put', 'download/a.txt', 'video']
        put_file = None
        status_code, user_info = self.auth(req_dic)
        if status_code != 200:
            resultCode = 201
        elif not 1 < len(cmds) < 4 or os.path.basename(cmds[1]) in ('', '.', '..'):
            resultCode = 301
        else:
            home = os.path.join(settings.user_home_dir, user_info['username'])
            used_quota = commons.getFileSize(user_info['home'])
            total_quota = user_info['quota']
            cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
            if len(cmds) == 3:
                put_file = os.path.join(os.path.join('', *self.parse_file_path(cmds[2], cur_path)),
                                        os.path.basename(cmds[1]))
            else:
                put_file = os.path.join(cur_path, os.path.basename(cmds[1]))
            print('put解析后的文件:', put_file)
            put_path = os.path.dirname(put_file)
            if not (os.path.isdir(put_path) and self._is_within(put_path, home)):
                resultCode = 302
            elif (req_dic['FileSize'] + used_quota) <= total_quota:
                resultCode = 300
            else:
                resultCode = 305
        self.response(resultCode=resultCode)
        if resultCode == 300:
            file_size = req_dic['FileSize']
            recv_size = 0
            with open(put_file, 'wb') as f:
                while recv_size < file_size:
                    # Never read past this file into whatever the client sends next.
                    file_data = self.request.recv(min(self.RECV_SIZE, file_size - recv_size))
                    if not file_data:  # client vanished mid-upload (old loop spun forever)
                        break
                    f.write(file_data)
                    recv_size += len(file_data)
            if recv_size < file_size:
                os.remove(put_file)
                raise ConnectionResetError('client disconnected during upload')
            # Verify the uploaded file's MD5.
            if commons.getFileMd5(put_file) == req_dic['FileMd5']:
                resultCode = 400
                print('\033[42;1m文件md5校检结果一致\033[0m')
                print('\033[42;1m文件上传成功,大小:%d,文件名:%s\033[0m' % (req_dic['FileSize'], put_file))
            else:
                os.remove(put_file)
                resultCode = 401
                print('\033[31;1m文件md5校检结果不一致\033[0m')
                print('\033[42;1m文件上传失败\033[0m')
            self.response(resultCode=resultCode)

server.py

三、socketserver模块(内部采用selectors模块)实现并发效果(Linux下支持epoll模型)

1、导入 socketserver模块

import socketserver

2、TCPHandler类继承socketserver.BaseRequestHandler

# socketserver instantiates this handler class for each client connection and calls its handle() method.
class TCPHandler(socketserver.BaseRequestHandler):

3、必须重写handle函数

4、完整代码:

server.py

 # -*- coding: utf-8 -*-
# from gevent import monkey;monkey.patch_all()
import socketserver
import socket
import os, json, re, struct, threading, time
import gevent
import select
from lib import commons
from conf import settings
from core import logger


class Server(object):
    """FTP server built on socketserver.ThreadingTCPServer (one thread per client)."""

    def __init__(self):
        self.init_dir()
        print("\033[42;1mserver started sucessful!\033[0m")
        self.run()

    @staticmethod
    def init_dir():
        """Create the logs, db and home directories under settings.base_path if missing."""
        for name in ('logs', 'db', 'home'):
            path = os.path.join(settings.base_path, name)
            if not os.path.exists(path):
                os.mkdir(path)

    def run(self):
        """Serve clients forever; each connection is handled by TCPHandler in its own thread.

        ``allow_reuse_address`` is read by server_bind(), which runs inside the
        constructor. It must therefore be a class attribute. Assigning it on the
        instance afterwards, as the old code did, had no effect.
        """
        class ReusableThreadingTCPServer(socketserver.ThreadingTCPServer):
            allow_reuse_address = True

        server = ReusableThreadingTCPServer(
            (settings.server_bind_ip, settings.server_bind_port), TCPHandler)
        try:
            server.serve_forever()
        finally:
            server.server_close()  # release the listening socket on exit


class TCPHandler(socketserver.BaseRequestHandler):
    """Serve FTP requests on one client connection (one instance per connection).

    Request wire format: a 4-byte ``struct`` 'i' length prefix, a UTF-8 JSON
    header (``cmd``, ``info_size`` and optionally ``md5``), then ``info_size``
    bytes of UTF-8 JSON body. Responses are framed by ``commons.make_header``
    (presumably the same layout -- confirm in lib/commons).
    """
    STATUS_CODE = {
        200: 'Passed authentication!',
        201: 'Wrong username or password!',
        202: 'Username does not exist!',
        300: 'cmd successful , the target path be returned in returnPath',
        301: 'cmd format error!',
        302: 'The path or file could not be found!',
        303: 'The dir is exist',
        304: 'The file has been downloaded or the size of the file is exceptions',
        305: 'Free space is not enough',
        401: 'File MD5 inspection failed',
        400: 'File MD5 inspection success',
    }
    # Commands a client may invoke. Dispatch used to be ``hasattr(self, cmd)``,
    # which also exposed helpers such as ``auth`` or ``handle`` to clients.
    COMMANDS = ('register', 'login', 'query_quota', 'cd', 'ls', 'rm', 'mkdir', 'get', 'put')
    # Upper bound for a single recv() call.
    RECV_SIZE = 1024
    # Block size used when streaming a file to the client.
    SEND_SIZE = 8192

    def __init__(self, request, client_address, server):
        # The logger must exist before BaseRequestHandler.__init__ runs handle().
        self.server_logger = logger.logger('server')
        self.server_logger.debug("server TCPHandler started successful!")
        super().__init__(request, client_address, server)

    def handle(self):
        """Serve requests on this connection until the client leaves or misbehaves."""
        self.server_logger.info('client[%s] is conecting' % ((self.client_address,)))
        print('client[%s] is conecting' % ((self.client_address,)))
        while True:  # request loop
            try:
                # 1. Receive the client's ftp command.
                print("waiting receive client[%s] ftp command.." % ((self.client_address,)), id(self), self)
                header_dic, req_dic = self.recv_request()
                if not header_dic or not header_dic.get('cmd'):
                    break
                print('receive client ftp command:%s' % header_dic['cmd'])
                # 2. Split the command into name and arguments, e.g. ['get', 'a.txt'].
                cmds = header_dic['cmd'].split()
                if cmds and cmds[0] in self.COMMANDS:
                    self.server_logger.info('interface:[%s], request:{client:[%s:%s] action:[%s]}' % (
                        cmds[0], self.client_address[0], self.client_address[1], header_dic['cmd']))
                    getattr(self, cmds[0])(header_dic, req_dic)
                else:
                    self.server_logger.warning('client[%s:%s] unknown command: %s' % (
                        self.client_address[0], self.client_address[1], header_dic['cmd']))
            except (ConnectionResetError, ConnectionAbortedError):
                break
            except socket.timeout:
                print('time out %s' % ((self.client_address,)))
                break
            except (OSError, ValueError, KeyError, TypeError, AttributeError, IndexError) as e:
                # Malformed request or failed file operation: the stream may be
                # out of sync, so end this connection.
                self.server_logger.warning('client[%s:%s] request failed: %r' % (
                    self.client_address[0], self.client_address[1], e))
                break
        self.request.close()
        self.server_logger.info('client %s is disconect' % ((self.client_address,)))
        print('client[%s:%s] is disconect' % (self.client_address[0], self.client_address[1]))

    def _recv_exact(self, size):
        """Return exactly ``size`` bytes read from the client socket.

        recv() on a stream socket may return fewer bytes than requested, so this
        loops until everything has arrived.

        Raises:
            ConnectionResetError: the client closed the connection first.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = self.request.recv(min(remaining, self.RECV_SIZE))
            if not chunk:
                raise ConnectionResetError('connection closed by client')
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def unpack_header(self):
        """Read the 4-byte length prefix and the JSON header it announces.

        Returns:
            The decoded header, or None if the announced size is not positive.

        Raises:
            ConnectionResetError: the client disconnected (also on a clean close).
            ValueError: the header is not valid UTF-8 JSON.
        """
        header_size = struct.unpack('i', self._recv_exact(4))[0]
        if header_size <= 0:  # guards against a corrupt/garbage length prefix
            return None
        header_json = self._recv_exact(header_size).decode('utf-8')
        return json.loads(header_json)

    def unpack_info(self, info_size):
        """Read ``info_size`` bytes of JSON request body.

        Returns:
            ``(info_dic, info_md5)``. An empty body yields ``{}``, because
            ``json.loads('')`` would raise.
        """
        info_bytes = self._recv_exact(info_size)
        info_dic = json.loads(info_bytes.decode('utf-8')) if info_bytes else {}
        info_md5 = commons.getStrsMd5(info_bytes)
        return info_dic, info_md5

    def recv_request(self):
        """Read one request (header plus body).

        Returns:
            ``(header_dic, req_dic)``, or ``(None, None)`` when no usable header arrived.
        """
        header_dic = self.unpack_header()  # e.g. {'cmd': 'register', 'info_size': 0}
        if not header_dic:
            return None, None
        req_dic, info_md5 = self.unpack_info(header_dic['info_size'])
        # The body checksum is optional. A mismatch is logged but the request is
        # still processed, as before.
        if header_dic.get('md5') and info_md5 != header_dic['md5']:
            self.server_logger.warning('request md5 mismatch, cmd: %s' % header_dic.get('cmd'))
        return header_dic, req_dic

    def response(self, **kwargs):
        """Send ``kwargs`` to the client as a framed JSON response."""
        rsp_bytes = commons.getDictBytes(kwargs)
        md5 = commons.getStrsMd5(rsp_bytes)
        header_size_pack, header_bytes = commons.make_header(info_size=len(rsp_bytes), md5=md5)
        self.request.sendall(header_size_pack)
        self.request.sendall(header_bytes)
        self.request.sendall(rsp_bytes)

    @staticmethod
    def _safe_username(name):
        """Return True if ``name`` can safely be used as a file/directory name.

        Usernames are joined into db and home paths, so separators, drive
        colons, NUL and '.'/'..' would allow path traversal.
        """
        return (isinstance(name, str) and name not in ('', '.', '..')
                and not any(c in name for c in '/\\:\0'))

    @staticmethod
    def _is_within(path, home, strict=False):
        """Return True if ``path`` is ``home`` or lies below it.

        If ``strict`` is true, ``home`` itself does not count. This replaces the
        substring test ``home in path``. That test accepted ``home\\..\\..``
        escapes and sibling homes sharing a prefix, such as ``ton`` and ``tony``.
        """
        path = os.path.normcase(os.path.normpath(path))
        home = os.path.normcase(os.path.normpath(home))
        if path == home:
            return not strict
        return path.startswith(home.rstrip(os.sep) + os.sep)

    def register(self, header_dic, req_dic):
        """Create a user account and home directory, then reply with resultCode/resultDesc.

        resultCode 0 means created. 1 means rejected: the user already exists or
        the username is invalid.
        """
        username = req_dic['user_info']['username']
        db_path = os.path.join(settings.db_file, '%s.json' % username)
        if self._safe_username(username) and not os.path.isfile(db_path):
            # Persist the user record.
            user_info = dict()
            user_info['username'] = username
            user_info['password'] = req_dic['user_info']['password']
            user_info['home'] = os.path.join(settings.user_home_dir, username)
            user_info['quota'] = settings.user_quota * (1024 * 1024)
            commons.save_to_file(user_info, db_path)
            resultCode = 0
            resultDesc = None
            # Create the home directory.
            if not os.path.exists(user_info['home']):
                os.mkdir(user_info['home'])
            self.server_logger.info('client[%s:%s] 注册用户[%s]成功' % (self.client_address[0], self.client_address[1], username))
        else:
            resultCode = 1
            if self._safe_username(username):
                resultDesc = '该用户已存在,注册失败'
            else:
                resultDesc = '用户名不合法,注册失败'
            self.server_logger.warning('client[%s:%s] 注册用户[%s]失败:%s' % (self.client_address[0], self.client_address[1],
                                                                          username, resultDesc))
        self.response(resultCode=resultCode, resultDesc=resultDesc)

    @staticmethod
    def auth(req_dic):
        """Check the username/password in ``req_dic['user_info']`` against the db.

        Returns:
            ``(status_code, user_info)``: ``(200, record)`` on success,
            ``(202, None)`` for an unknown or invalid username, and
            ``(201, None)`` for a wrong password or a malformed request.
        """
        try:
            req_username = req_dic['user_info']['username']
            if not TCPHandler._safe_username(req_username):
                return 202, None
            db_file = os.path.join(settings.db_file, '%s.json' % req_username)
            if not os.path.isfile(db_file):
                return 202, None
            with open(db_file, 'r') as f:
                user_info_db = json.load(f)
            if user_info_db['password'] == req_dic['user_info']['password']:
                return 200, user_info_db
            return 201, None
        # The client sent an empty or wrongly shaped auth dict.
        except (KeyError, TypeError):
            return 201, None

    def login(self, header_dic, req_dic):
        """Authenticate and reply with the user record and status code."""
        status_code, user_info = self.auth(req_dic)
        self.response(user_info=user_info, resultCode=status_code)

    def query_quota(self, header_dic, req_dic):
        """Reply with the used and total quota (both None unless authenticated)."""
        used_quota = None
        total_quota = None
        status_code, user_info = self.auth(req_dic)
        if status_code == 200:
            used_quota = commons.getFileSize(user_info['home'])
            total_quota = user_info['quota']
        self.response(resultCode=status_code, total_quota=total_quota, used_quota=used_quota)

    @staticmethod
    def parse_file_path(req_path, cur_path):
        """Resolve ``req_path`` against ``cur_path`` and return the path parts.

        Windows-style resolution:
        - '/' becomes '\\' and pairs of backslashes are collapsed.
        - The first '~\\', then the first '~', are removed.
        - '.' parts are dropped.
        - Each '..' removes itself together with the part before it.

        The first part gets a trailing '\\' so that ``os.path.join('', *parts)``
        rebuilds a rooted path (presumably a drive such as ``C:``).
        """
        req_path = req_path.replace(r'/', '\\')
        req_path = req_path.replace('\\\\', '\\')
        req_path = req_path.replace('~\\', '', 1)
        req_path = req_path.replace(r'~', '', 1)
        req_paths = re.findall(r'[^\\]+', req_path)
        cur_paths = re.findall(r'[^\\]+', cur_path)
        cur_paths.extend(req_paths)
        cur_paths[0] += '\\'
        while '.' in cur_paths:
            cur_paths.remove('.')
        while '..' in cur_paths:
            index = cur_paths.index('..')
            cur_paths.pop(index)
            cur_paths.pop(index - 1)
        return cur_paths

    def cd(self, header_dic, req_dic):
        """Change the client's current directory within its home; reply with returnPath."""
        cmds = header_dic['cmd'].split()
        status_code, user_info = self.auth(req_dic)
        returnPath = req_dic['user_info']['cur_path']
        if status_code != 200:
            resultCode = 201
        elif len(cmds) == 1:
            resultCode = 301
        else:
            home = os.path.join(settings.user_home_dir, user_info['username'])
            cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
            cd_path = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
            print('cd解析后的路径:', cd_path)
            if os.path.isdir(cd_path) and self._is_within(cd_path, home):
                resultCode = 300
                returnPath = cd_path.replace('%s\\' % settings.user_home_dir, '', 1)
            else:
                resultCode = 302
        print('cd发送给客户端的路径:', returnPath)
        self.response(resultCode=resultCode, returnPath=returnPath)

    def ls(self, header_dic, req_dic):
        """List a directory within the client's home; reply with returnFilenames."""
        cmds = header_dic['cmd'].split()
        status_code, user_info = self.auth(req_dic)
        returnFilenames = None
        if status_code != 200:
            resultCode = 201
        elif len(cmds) > 2:
            resultCode = 301
        else:
            home = os.path.join(settings.user_home_dir, user_info['username'])
            cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
            if len(cmds) == 2:
                ls_path = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
            else:
                ls_path = cur_path
            print('ls解析后的路径:', ls_path)
            if os.path.isdir(ls_path) and self._is_within(ls_path, home):
                _, returnFilenames = commons.getFile(ls_path, home)
                resultCode = 300
            else:
                resultCode = 302
        self.response(resultCode=resultCode, returnFilenames=returnFilenames)

    def rm(self, header_dic, req_dic):
        """Delete a file or directory strictly inside the client's home."""
        cmds = header_dic['cmd'].split()
        status_code, user_info = self.auth(req_dic)
        if status_code != 200:
            resultCode = 201
        elif len(cmds) != 2:
            resultCode = 301
        else:
            home = os.path.join(settings.user_home_dir, user_info['username'])
            cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
            rm_path = os.path.join('', *self.parse_file_path(os.path.dirname(cmds[1]), cur_path))
            rm_file = os.path.join(rm_path, os.path.basename(cmds[1]))
            print('rm解析后的文件或文件夹:', rm_file)
            # strict: never delete the home directory itself (e.g. "rm .").
            if os.path.exists(rm_file) and self._is_within(rm_file, home, strict=True):
                commons.rmdirs(rm_file)
                resultCode = 300
            else:
                resultCode = 302
        self.response(resultCode=resultCode)

    def mkdir(self, header_dic, req_dic):
        """Create a directory (with parents) inside the client's home."""
        cmds = header_dic['cmd'].split()
        status_code, user_info = self.auth(req_dic)
        if status_code != 200:
            resultCode = 201
        elif len(cmds) != 2:
            resultCode = 301
        else:
            home = os.path.join(settings.user_home_dir, user_info['username'])
            cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
            mkdir_path = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
            print('mkdir解析后的文件夹:', mkdir_path)
            if os.path.isdir(mkdir_path):
                resultCode = 303
            elif self._is_within(mkdir_path, home):
                os.makedirs(mkdir_path)
                resultCode = 300
            else:
                resultCode = 302
        self.response(resultCode=resultCode)

    def get(self, header_dic, req_dic):
        """Send a file to the client, resuming at ``req_dic['position']`` if requested."""
        cmds = header_dic['cmd'].split()  # ['get', 'a.txt', 'download']
        get_file = None
        FileSize = None
        FileMd5 = None
        status_code, user_info = self.auth(req_dic)
        # Resume offset: only honoured for a positive int when 'resume' is set.
        position = req_dic.get('position')
        if not (req_dic.get('resume') and isinstance(position, int) and position > 0):
            position = 0
        if status_code != 200:
            resultCode = 201
        elif not 1 < len(cmds) < 4:
            resultCode = 301
        else:
            home = os.path.join(settings.user_home_dir, user_info['username'])
            cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
            get_file = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
            print('get解析后的路径:', get_file)
            if not (os.path.isfile(get_file) and self._is_within(get_file, home)):
                resultCode = 302
            else:
                FileSize = commons.getFileSize(get_file)
                if position >= FileSize != 0:  # nothing left to send
                    resultCode = 304
                else:
                    resultCode = 300
                    FileMd5 = commons.getFileMd5(get_file)
        self.response(resultCode=resultCode, FileSize=FileSize, FileMd5=FileMd5)
        if resultCode == 300:
            with open(get_file, 'rb') as f:
                f.seek(position)
                # sendall(): send() may transmit only part of each block.
                for block in iter(lambda: f.read(self.SEND_SIZE), b''):
                    self.request.sendall(block)

    def put(self, header_dic, req_dic):
        """Receive an upload into the client's home, enforcing its quota and MD5."""
        cmds = header_dic['cmd'].split()  # ['put', 'download/a.txt', 'video']
        put_file = None
        status_code, user_info = self.auth(req_dic)
        if status_code != 200:
            resultCode = 201
        elif not 1 < len(cmds) < 4 or os.path.basename(cmds[1]) in ('', '.', '..'):
            resultCode = 301
        else:
            home = os.path.join(settings.user_home_dir, user_info['username'])
            used_quota = commons.getFileSize(user_info['home'])
            total_quota = user_info['quota']
            cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
            if len(cmds) == 3:
                put_file = os.path.join(os.path.join('', *self.parse_file_path(cmds[2], cur_path)),
                                        os.path.basename(cmds[1]))
            else:
                put_file = os.path.join(cur_path, os.path.basename(cmds[1]))
            print('put解析后的文件:', put_file)
            put_path = os.path.dirname(put_file)
            if not (os.path.isdir(put_path) and self._is_within(put_path, home)):
                resultCode = 302
            elif (req_dic['FileSize'] + used_quota) <= total_quota:
                resultCode = 300
            else:
                resultCode = 305
        self.response(resultCode=resultCode)
        if resultCode == 300:
            file_size = req_dic['FileSize']
            recv_size = 0
            with open(put_file, 'wb') as f:
                while recv_size < file_size:
                    # Never read past this file into whatever the client sends next.
                    file_data = self.request.recv(min(self.RECV_SIZE, file_size - recv_size))
                    if not file_data:  # client vanished mid-upload (old loop spun forever)
                        break
                    f.write(file_data)
                    recv_size += len(file_data)
            if recv_size < file_size:
                os.remove(put_file)
                raise ConnectionResetError('client disconnected during upload')
            # Verify the uploaded file's MD5.
            if commons.getFileMd5(put_file) == req_dic['FileMd5']:
                resultCode = 400
                print('\033[42;1m文件md5校检结果一致\033[0m')
                print('\033[42;1m文件上传成功,大小:%d,文件名:%s\033[0m' % (req_dic['FileSize'], put_file))
            else:
                os.remove(put_file)
                resultCode = 401
                print('\033[31;1m文件md5校检结果不一致\033[0m')
                print('\033[42;1m文件上传失败\033[0m')
            self.response(resultCode=resultCode)

server.py

PS:记得把from gevent import monkey;monkey.patch_all() 注释掉

四、selectors模块实现单线程并发效果

1、导入selectors模块

import selectors

2、把accept接受新客户端请求和recv数据分别写成两个函数

def accept(self, server):
    """Accept a new client and register it with the selector.

    The client socket is made non-blocking and registered for EVENT_READ with
    ``self.handle`` as callback. ``self.dic[request] = {}`` records that no
    transfer is in progress yet.
    """
    try:
        request, client_address = server.accept()
    except BlockingIOError:
        # The pending connection vanished between select() and accept().
        return
    request.setblocking(False)
    print('client%s is conecting' % ((client_address,)))
    self.sel.register(request, selectors.EVENT_READ, self.handle)
    self.dic[request] = {}


def handle(self, request):
    """Process one readiness event on a client socket.

    If ``self.dic[request]`` is empty, a new request is read and dispatched.
    Otherwise the stored header_dic/req_dic of an unfinished transfer are
    replayed to the same command. On disconnect the socket is unregistered,
    closed and its state dropped.
    """
    def close_client():
        # Unregister before closing (selectors requirement) and drop the state
        # entry. Previously it leaked in self.dic.
        print('client[%s] is disconect' % (self.client_address,))
        self.sel.unregister(request)
        request.close()
        self.dic.pop(request, None)

    self.request = request
    try:
        self.client_address = request.getpeername()
    except OSError:  # the peer already reset the connection
        self.client_address = None
    try:
        state = self.dic[request]
        if not state:
            # 1. Receive the client's ftp command.
            header_dic, req_dic = self.recv_request()
            if not header_dic or not header_dic.get('cmd'):
                close_client()
                return  # previously fell through and crashed on header_dic['cmd']
            print('receive client ftp command:%s' % header_dic['cmd'])
        else:
            header_dic = state['header_dic']
            req_dic = state['req_dic']
        # 2. Split the command into name and arguments, e.g. ['get', 'a.txt'].
        cmds = header_dic['cmd'].split()
        if cmds and hasattr(self, cmds[0]):
            getattr(self, cmds[0])(header_dic, req_dic)
    except BlockingIOError:
        pass  # no more data available yet; wait for the next readiness event
    except (ConnectionResetError, ConnectionAbortedError, socket.timeout) as e:
        print('error: ', e)
        close_client()

3、循环检测所有已注册的fileobj,找出已完成wait data(数据已就绪)的对象,并调用注册时绑定的回调函数

def run(self):
    """Event loop: wait for ready sockets and invoke the callback registered with each.

    The callback is whatever was passed as ``data`` to ``sel.register``, and it
    receives the ready file object as its only argument.
    """
    while True:
        for key, _event_mask in self.sel.select():  # blocks until some registered socket is ready
            handler = key.data  # e.g. accept for the listening socket
            handler(key.fileobj)

4、完整代码:

 # -*- coding: utf-8 -*-
# from gevent import monkey;monkey.patch_all()
# import socketserver
import selectors
import socket
import os, json, re, struct, threading, time
# import gevent
# import select
from lib import commons
from conf import settings
from core import logger


class Server(object):
    """Single-threaded FTP server driven by a selectors.DefaultSelector event loop."""

    def __init__(self):
        self.init_dir()
        self.server_logger = logger.logger('server')
        self.sel = selectors.DefaultSelector()
        self.dic = {}  # per-client state of file transfers that have not finished yet
        self.create_socket()
        self.run()

    @staticmethod
    def init_dir():
        """Create the logs, db and home directories under settings.base_path if missing."""
        for name in ('logs', 'db', 'home'):
            path = os.path.join(settings.base_path, name)
            if not os.path.exists(path):
                os.mkdir(path)

    def create_socket(self):
        """Bind and listen on the server socket, then register it with ``self.accept``."""
        self.server = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
        self.server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.server.bind((settings.server_bind_ip, settings.server_bind_port))
        self.server.listen(settings.server_listen)
        self.server.setblocking(False)  # the event loop must never block in accept()
        self.sel.register(self.server, selectors.EVENT_READ, self.accept)
        print("\033[42;1mserver started sucessful!\033[0m")

    def accept(self, server):
        """Accept a new client and register it for EVENT_READ with ``self.handle``."""
        try:
            request, client_address = server.accept()
        except BlockingIOError:
            # The pending connection vanished between select() and accept();
            # previously this exception escaped and stopped the event loop.
            return
        request.setblocking(False)
        print('client%s is conecting' % ((client_address,)))
        self.sel.register(request, selectors.EVENT_READ, self.handle)
        self.dic[request] = {}

    def run(self):
        """Event loop: call the callback registered with each ready socket, forever."""
        while True:
            events = self.sel.select()  # blocks until some registered socket is ready
            for sel_obj, mask in events:
                callback = sel_obj.data  # self.accept for the listener, self.handle for clients
                callback(sel_obj.fileobj,)
# socketserver模块实现多并发
# server = socketserver.ThreadingTCPServer((settings.server_bind_ip, settings.server_bind_port), TCPHandler)
# server.allow_reuse_address = True
# server.serve_forever()
# while True: # 链接循环
# # select 单进程实现同时处理请求
# inputs = [self.server, ]
# outputs = []
# while True:
# readable, writeable, exceptional = select.select(inputs, outputs, inputs)
# for r in readable:
# if r is self.server:
# request, client_address = self.server.accept()
# inputs.append(request)
# else:
# print('处理request:%s'%id(r))
# return_code, request = TCPHandler().handle(r, )
# if not return_code:
# request.close()
# print('client[%s] is disconect' % ((request.getpeername()),))
# self.request, self.client_address = self.server.accept()
# self.request.settimeout(300)
# # 多线程处理请求
# thread = threading.Thread(target=TCPHandler.handle, args=(TCPHandler(), self.request, self.client_address))
# thread.start()
# # gevent 实现单线程多并发
# # gevent.spawn(TCPHandler.handle, TCPHandler(), self.request, self.client_address)
# class TCPHandler():
STATUS_CODE = {
200: 'Passed authentication!',
201: 'Wrong username or password!',
202: 'Username does not exist!',
300: 'cmd successful , the target path be returned in returnPath',
301: 'cmd format error!',
302: 'The path or file could not be found!',
303: 'The dir is exist',
304: 'The file has been downloaded or the size of the file is exceptions',
305: 'Free space is not enough',
401: 'File MD5 inspection failed',
400: 'File MD5 inspection success',
} # def __init__(self, request, client_address):
# self.server_logger = logger.logger('server')
# self.server_logger.debug("server TCPHandler started successful!")
# self.request = request
# self.client_address = client_address
# super().__init__(request, client_address, server) def handle(self, request):
self.request = request
self.client_address = request.getpeername()
try:
if not self.dic[request]:
# 1、接收客户端的ftp命令
header_dic, req_dic = self.recv_request()
if not header_dic or not header_dic['cmd']:
print('client[%s] is disconect' % ((request.getpeername()),))
request.close()
self.sel.unregister(request)
print( 'receive client ftp command:%s' % header_dic['cmd'])
else:
header_dic = self.dic[request]['header_dic']
req_dic = self.dic[request]['req_dic']
# 2、解析ftp命令,获取相应命令参数(文件名)
cmds = header_dic['cmd'].split() # ['register',]、['get', 'a.txt']
if not self.dic[request]:
if hasattr(self, cmds[0]):
getattr(self, cmds[0])(header_dic, req_dic)
else:
getattr(self, cmds[0])(self.dic[request]['header_dic'], self.dic[request]['req_dic'])
except BlockingIOError as e:
pass
except (ConnectionResetError, ConnectionAbortedError, socket.timeout) as e:
print('error: ',e)
print('client[%s] is disconect' % ((request.getpeername()),))
request.close()
self.sel.unregister(request)
# self.request.close()
# self.server_logger.info('client %s is disconect' % ((self.client_address,)))
# print('client[%s:%s] is disconect' % (self.client_address[0], self.client_address[1])) def unpack_header(self):
try:
pack_obj = self.request.recv(4)
header_size = struct.unpack('i', pack_obj)[0]
time.sleep(1/(10**20))
header_bytes = self.request.recv(header_size)
header_json = header_bytes.decode('utf-8')
header_dic = json.loads(header_json)
return header_dic
except struct.error: # 避免客户端发送错误格式的header_size
return def unpack_info(self, info_size):
recv_size = 0
info_bytes = b''
while recv_size < info_size:
res = self.request.recv(1024)
info_bytes += res
recv_size += len(res)
info_json = info_bytes.decode('utf-8')
info_dic = json.loads(info_json) # {'username':ton, 'password':123}
info_md5 = commons.getStrsMd5(info_bytes)
return info_dic, info_md5 def recv_request(self):
header_dic = self.unpack_header() # {'cmd':'register','info_size':0}
if not header_dic: return None, None
req_dic, info_md5 = self.unpack_info(header_dic['info_size'])
if header_dic.get('md5'):
# 校检请求内容md5一致性
if info_md5 == header_dic['md5']:
pass
# print('\033[42;1m请求内容md5校检结果一致\033[0m')
else:
pass
# print('\033[31;1m请求内容md5校检结果不一致\033[0m')
return header_dic, req_dic def response(self, **kwargs):
rsp_info = kwargs
rsp_bytes = commons.getDictBytes(rsp_info)
md5 = commons.getStrsMd5(rsp_bytes)
header_size_pack, header_bytes = commons.make_header(info_size=len(rsp_bytes), md5=md5)
self.request.sendall(header_size_pack)
self.request.sendall(header_bytes)
self.request.sendall(rsp_bytes) def register(self, header_dic, req_dic): # {'cmd':'register','info_size':0,'resultCode':0,'resultDesc':None}
username = req_dic['user_info']['username']
# 更新数据库,并制作响应信息字典
if not os.path.isfile(os.path.join(settings.db_file, '%s.json' % username)):
# 更新数据库
user_info = dict()
user_info['username'] = username
user_info['password'] = req_dic['user_info']['password']
user_info['home'] = os.path.join(settings.user_home_dir, username)
user_info['quota'] = settings.user_quota * (1024 * 1024)
commons.save_to_file(user_info, os.path.join(settings.db_file, '%s.json' % username))
resultCode = 0
resultDesc = None
# 创建家目录
if not os.path.exists(os.path.join(settings.user_home_dir, username)):
os.mkdir(os.path.join(settings.user_home_dir, username))
self.server_logger.info('client[%s:%s] 注册用户[%s]成功' % (self.client_address[0], self.client_address[1], username))
else:
resultCode = 1
resultDesc = '该用户已存在,注册失败'
self.server_logger.warning('client[%s:%s] 注册用户[%s]失败:%s' % (self.client_address[0], self.client_address[1],
username, resultDesc))
# 响应客户端注册请求
self.response(resultCode=resultCode, resultDesc=resultDesc) @staticmethod
def auth(req_dic):
# print(req_dic['user_info'])
user_info = None
status_code = 201
try:
req_username = req_dic['user_info']['username']
db_file = os.path.join(settings.db_file, '%s.json' % req_username)
# 验证用户名密码,并制作响应信息字典
if not os.path.isfile(db_file):
status_code = 202
else:
with open(db_file, 'r') as f:
user_info_db = json.load(f)
if user_info_db['password'] == req_dic['user_info']['password']:
status_code = 200
user_info = user_info_db
return status_code, user_info
# 捕获 客户端鉴权请求时发送一个空字典或错误的字典 的异常
except KeyError:
return 201, user_info def login(self, header_dic, req_dic):
# 鉴权
status_code, user_info = self.auth(req_dic)
# 响应客户端登陆请求
self.response(user_info=user_info, resultCode=status_code) def query_quota(self, header_dic, req_dic):
used_quota = None
total_quota = None
# 鉴权
status_code, user_info = self.auth(req_dic)
# 查询配额
if status_code == 200:
used_quota = commons.getFileSize(user_info['home'])
total_quota = user_info['quota']
# 响应客户端配额查询请求
self.response(resultCode=status_code, total_quota=total_quota, used_quota=used_quota) @staticmethod
def parse_file_path(req_path, cur_path):
req_path = req_path.replace(r'/', '\\')
req_path = req_path.replace(r'//', r'/', )
req_path = req_path.replace('\\\\', '\\')
req_path = req_path.replace('~\\', '', 1)
req_path = req_path.replace(r'~', '', 1)
req_paths = re.findall(r'[^\\]+', req_path)
cur_paths = re.findall(r'[^\\]+', cur_path)
cur_paths.extend(req_paths)
cur_paths[0] += '\\'
while '.' in cur_paths:
cur_paths.remove('.')
while '..' in cur_paths:
for index, item in enumerate(cur_paths):
if item == '..':
cur_paths.pop(index)
cur_paths.pop(index - 1)
break
return cur_paths def cd(self, header_dic, req_dic):
cmds = header_dic['cmd'].split()
# 鉴权
status_code, user_info = self.auth(req_dic)
home = os.path.join(settings.user_home_dir, user_info['username'])
# 先定义响应信息
returnPath = req_dic['user_info']['cur_path']
if status_code == 200:
if len(cmds) != 1:
# 解析cd的真实路径
cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
cd_path = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
print('cd解析后的路径:', cd_path)
if os.path.isdir(cd_path):
if home in cd_path:
resultCode = 300
returnPath = cd_path.replace('%s\\' % settings.user_home_dir, '', 1)
else:
resultCode = 302
else:
resultCode = 302
else:
resultCode = 301
else:
resultCode = 201
# 响应客户端的cd命令结果
print('cd发送给客户端的路径:', returnPath)
self.response(resultCode=resultCode, returnPath=returnPath) def ls(self, header_dic, req_dic):
cmds = header_dic['cmd'].split()
# 鉴权
status_code, user_info = self.auth(req_dic)
home = os.path.join(settings.user_home_dir, user_info['username'])
# 先定义响应信息
returnFilenames = None
if status_code == 200:
if len(cmds) <= 2:
# 解析ls的真实路径
cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
if len(cmds) == 2:
ls_path = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
else:
ls_path = cur_path
print('ls解析后的路径:', ls_path)
if os.path.isdir(ls_path):
if home in ls_path:
returnCode, filenames = commons.getFile(ls_path, home)
resultCode = 300
returnFilenames = filenames
else:
resultCode = 302
else:
resultCode = 302
else:
resultCode = 301
else:
resultCode = 201
# 响应客户端的ls命令结果
self.response(resultCode=resultCode, returnFilenames=returnFilenames) def rm(self, header_dic, req_dic):
cmds = header_dic['cmd'].split()
# 鉴权
status_code, user_info = self.auth(req_dic)
home = os.path.join(settings.user_home_dir, user_info['username'])
# 先定义响应信息
if status_code == 200:
if len(cmds) == 2:
# 解析rm的真实路径
cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
rm_path = os.path.join('', *self.parse_file_path(os.path.dirname(cmds[1]), cur_path))
rm_file = os.path.join(rm_path, os.path.basename(cmds[1]))
print('rm解析后的文件或文件夹:', rm_file)
if os.path.exists(rm_file):
if home in rm_file:
commons.rmdirs(rm_file)
resultCode = 300
else:
resultCode = 302
else:
resultCode = 302
else:
resultCode = 301
else:
resultCode = 201
# 响应客户端的rm命令结果
self.response(resultCode=resultCode) def mkdir(self, header_dic, req_dic):
cmds = header_dic['cmd'].split()
# 鉴权
status_code, user_info = self.auth(req_dic)
home = os.path.join(settings.user_home_dir, user_info['username'])
# 先定义响应信息
if status_code == 200:
if len(cmds) == 2:
# 解析rm的真实路径
cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
mkdir_path = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
print('mkdir解析后的文件夹:', mkdir_path)
if not os.path.isdir(mkdir_path):
if home in mkdir_path:
os.makedirs(mkdir_path)
resultCode = 300
else:
resultCode = 302
else:
resultCode = 303
else:
resultCode = 301
else:
resultCode = 201
# 响应客户端的mkdir命令结果
self.response(resultCode=resultCode) def get(self, header_dic, req_dic):
"""客户端下载文件"""
cmds = header_dic['cmd'].split() # ['get', 'a.txt', 'download']
# 解析需要get文件的真实路径
cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
get_file = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
# 鉴权
status_code, user_info = self.auth(req_dic)
home = os.path.join(settings.user_home_dir, user_info['username'])
# 解析断点续传信息
position = 0
if req_dic['resume'] and isinstance(req_dic['position'], int):
position = req_dic['position']
# 先定义响应信息
resultCode = 300
FileSize = None
FileMd5 = None
if status_code == 200:
if 1 < len(cmds) < 4:
print('get解析后的路径:', get_file)
if os.path.isfile(get_file):
if home in get_file:
FileSize = commons.getFileSize(get_file)
if position >= FileSize != 0:
resultCode = 304
else:
resultCode = 300
FileSize = FileSize
FileMd5 = commons.getFileMd5(get_file)
else:
resultCode = 302
else:
resultCode = 302
else:
resultCode = 301
else:
resultCode = 201
# 响应客户端的get命令结果
self.response(resultCode=resultCode, FileSize=FileSize, FileMd5=FileMd5)
if resultCode == 300:
self.request.setblocking(False)
# 发送文件数据
with open(get_file, 'rb') as f:
f.seek(position)
for line in f:
self.request.send(line)
position += len(line) def put(self, header_dic, req_dic):
cmds = header_dic['cmd'].split() # ['put', 'download/a.txt', 'video']
# 解析需要put文件的真实路径
cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
if len(cmds) == 3:
put_file = os.path.join(os.path.join('', *self.parse_file_path(cmds[2], cur_path)),
os.path.basename(cmds[1]))
else:
put_file = os.path.join(cur_path, os.path.basename(cmds[1]))
if not self.dic[self.request]:
self.dic[self.request]['action'] = 'put'
self.dic[self.request]['header_dic'] = header_dic
self.dic[self.request]['req_dic'] = req_dic
self.dic[self.request]['position'] = 0
# 鉴权
status_code, user_info = self.auth(req_dic)
home = os.path.join(settings.user_home_dir, user_info['username'])
# 查询配额
used_quota = commons.getFileSize(user_info['home'])
total_quota = user_info['quota']
# 先定义响应信息
if status_code == 200:
if 1 < len(cmds) < 4:
print('put解析后的文件:', put_file)
put_path = os.path.dirname(put_file)
if os.path.isdir(put_path):
if home in put_path:
if (req_dic['FileSize'] + used_quota) <= total_quota:
resultCode = 300
else:
resultCode = 305
else:
resultCode = 302
else:
resultCode = 302
else:
resultCode = 301
else:
resultCode = 201
# 响应客户端的put命令结果
self.response(resultCode=resultCode)
if resultCode == 300:
# 接收文件数据,写入文件
recv_size = 0
with open(put_file, 'wb') as f:
while recv_size < req_dic['FileSize']:
file_data = self.request.recv(1024)
f.write(file_data)
recv_size += len(file_data)
self.dic[self.request]['position'] = recv_size
else:
# 接收文件数据,写入文件
recv_size = self.dic[self.request]['position']
with open(put_file, 'ab') as f:
while recv_size < req_dic['FileSize']:
file_data = self.request.recv(1024)
f.write(file_data)
recv_size += len(file_data)
self.dic[self.request]['position'] = recv_size
# 校检文件md5一致性
if commons.getFileMd5(put_file) == req_dic['FileMd5']:
resultCode = 400
print('\033[42;1m文件md5校检结果一致\033[0m')
print('\033[42;1m文件上传成功,大小:%d,文件名:%s\033[0m' % (req_dic['FileSize'], put_file))
else:
os.remove(put_file)
resultCode = 401
print('\033[31;1m文件md5校检结果不一致\033[0m')
print('\033[42;1m文件上传失败\033[0m')
# 返回上传文件是否成功响应
self.response(resultCode=resultCode)
self.dic[self.request] = {}

server.py