whisper 语音识别
flyfish
Whisper 是一种通用的语音识别模型。它在大量多样化的音频数据集上进行了训练,同时也是一个多任务模型,能够执行多语言语音识别、语音翻译和语言识别。
这是一个基于 Transformer 的序列到序列模型,训练了多种语音处理任务,包括多语言语音识别、语音翻译、口语语言识别和语音活动检测。这些任务被联合表示为一系列由解码器预测的标记,从而使得单个模型可以替代传统语音处理流程中的多个阶段。多任务训练格式使用了一组特殊标记,作为任务指定符或分类目标。
从视频中提取音频
import argparse
from moviepy import VideoFileClip
class AudioExtractor:
def __init__(self, video_path, audio_path):
self.video_path = video_path
self.audio_path = audio_path
def extract(self):
try:
video = VideoFileClip(self.video_path)
audio = video.audio
audio.write_audiofile(self.audio_path)
print("音频提取成功!")
except Exception as e:
print(f"提取音频时出现错误: {e}")
finally:
if 'video' in locals():
video.close()
if 'audio' in locals():
audio.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='从视频中提取音频')
parser.add_argument('--video', default='your_video.mp4', help='视频文件路径,默认为 your_video.mp4')
parser.add_argument('--audio', default='extracted_audio.mp3', help='音频文件保存路径,默认为 extracted_audio.mp3')
args = parser.parse_args()
extractor = AudioExtractor(args.video, args.audio)
extractor.extract()
语音识别
长音频文件的处理方法
对于长音频文件,直接一次性处理可能会导致显存不足或性能下降。以下是一些建议:
-
合理设置
chunk_length_s
和stride_length_s
- 将音频分割成较短的片段(如 30 秒),根据硬件性能和模型能力调整此值。
- 设置适当的重叠时间(如 5 秒),增加重叠时间可以提高转录结果的连贯性,但会增加计算量。
- 示例:
chunk_length_s=30, stride_length_s=5
-
分批处理
- 使用
batch_size
参数来分批处理音频片段,避免一次性加载过多数据。
- 使用
import argparse
from transformers import pipeline
class TranscriberFactory:
"""
用于创建语音识别管道 (transcriber)
"""
@staticmethod
def create_transcriber(
model_path, # 模型路径或名称
device="cuda:0", # 运行设备,"cpu" 或 "cuda:0"(GPU)
chunk_length_s=30, # 每个片段的长度(秒)
stride_length_s=5, # 片段之间的重叠时间(秒)
return_timestamps=True # 是否返回时间戳信息
):
"""
创建并返回一个语音识别管道对象。
:param model_path: 指定模型的路径或名称。
:param device: 指定运行设备,如 "cpu" 或 "cuda:0"(默认使用 GPU)。
:param chunk_length_s: 处理长音频时,每个片段的长度(秒),默认为 30 秒。
:param stride_length_s: 片段之间的重叠时间(秒),用于平滑转录结果,默认为 5 秒。
:param return_timestamps: 是否返回每个片段的时间戳信息,默认为 True。
:return: 返回一个 transcriber 对象,用于执行语音识别任务。
"""
return pipeline(
task="automatic-speech-recognition", # 任务类型:自动语音识别 (ASR)
model=model_path, # 指定使用的模型路径或名称
device=device, # 指定运行设备
framework="pt", # 使用 PyTorch 框架 ("pt")
tokenizer=None, # 分词器,默认使用与模型关联的分词器
feature_extractor=None, # 特征提取器,默认使用与模型关联的特征提取器
model_kwargs=None, # 额外传递给模型初始化的参数
pipeline_class=None, # 自定义管道类,默认使用 Hugging Face 提供的标准管道
chunk_length_s=chunk_length_s, # 每个片段的长度(秒)
stride_length_s=stride_length_s, # 片段之间的重叠时间(秒)
return_timestamps=return_timestamps # 是否返回时间戳信息
)
class ConfigurationManager:
"""
用于管理配置信息
"""
_instance = None # 保存单例实例
def __new__(cls, *args, **kwargs):
"""
确保 ConfigurationManager 只有一个实例。
:return: 返回唯一的实例对象。
"""
if cls._instance is None: # 如果实例不存在,则创建新实例
cls._instance = super().__new__(cls)
return cls._instance # 返回已存在的实例
def __init__(self, model_path, audio_path):
"""
初始化配置管理器。
:param model_path: 模型路径。
:param audio_path: 音频文件路径。
"""
self.model_path = model_path # 模型路径
self.audio_path = audio_path # 音频文件路径
def main():
"""
主函数:解析命令行参数并执行语音识别任务
"""
# 1. 解析命令行参数
parser = argparse.ArgumentParser(description="语音识别工具") # 创建 ArgumentParser 对象
parser.add_argument(
"--model_path", # 参数名
type=str, # 参数类型
default="/home/user/whisper-large-v3-zh/", # 默认值
help="模型路径" # 帮助信息
)
parser.add_argument(
"--audio_path", # 参数名
type=str, # 参数类型
default="1.mp3", # 默认值
help="音频文件路径,默认为 '1.mp3'" # 帮助信息
)
args = parser.parse_args() # 解析命令行参数
# 2. 初始化配置管理器
config_manager = ConfigurationManager(args.model_path, args.audio_path) # 创建配置管理器实例
# 3. 创建语音识别管道
transcriber = TranscriberFactory.create_transcriber(config_manager.model_path) # 使用工厂类创建 transcriber
# 4. 执行语音识别
result = transcriber(
config_manager.audio_path, # 输入音频文件路径
batch_size=1, # 每个批次处理的输入数量,默认为 1
generate_kwargs=None, # 传递给模型 generate 方法的额外参数,默认为 None
max_new_tokens=None # 生成的最大新标记数,默认为 None
)
# 5. 输出结果
print("完整文本:") # 打印完整转录文本
print(result["text"]) # 转录结果的完整文本部分
print("\n带时间戳的片段:") # 打印带时间戳的片段
if "chunks" in result: # 如果结果中包含时间戳信息
for chunk in result["chunks"]: # 遍历每个片段
print(f"时间范围: {chunk['timestamp']} -> 文本: {chunk['text']}") # 打印时间范围和对应文本
if __name__ == "__main__":
main() # 调用主函数
以下是加入了详细中文注释的代码:
import argparse
from transformers import pipeline
class TranscriberFactory:
"""
工厂类:用于创建语音识别管道 (transcriber)
"""
@staticmethod
def create_transcriber(
model_path, # 模型路径或名称
device="cuda:0", # 运行设备,"cpu" 或 "cuda:0"(GPU)
chunk_length_s=30, # 每个片段的长度(秒)
stride_length_s=5, # 片段之间的重叠时间(秒)
return_timestamps=True # 是否返回时间戳信息
):
"""
创建并返回一个语音识别管道对象。
:param model_path: 指定模型的路径或名称。
:param device: 指定运行设备,如 "cpu" 或 "cuda:0"(默认使用 GPU)。
:param chunk_length_s: 处理长音频时,每个片段的长度(秒),默认为 30 秒。
:param stride_length_s: 片段之间的重叠时间(秒),用于平滑转录结果,默认为 5 秒。
:param return_timestamps: 是否返回每个片段的时间戳信息,默认为 True。
:return: 返回一个 transcriber 对象,用于执行语音识别任务。
"""
return pipeline(
task="automatic-speech-recognition", # 任务类型:自动语音识别 (ASR)
model=model_path, # 指定使用的模型路径或名称
device=device, # 指定运行设备
framework="pt", # 使用 PyTorch 框架 ("pt")
tokenizer=None, # 分词器,默认使用与模型关联的分词器
feature_extractor=None, # 特征提取器,默认使用与模型关联的特征提取器
model_kwargs=None, # 额外传递给模型初始化的参数
pipeline_class=None, # 自定义管道类,默认使用 Hugging Face 提供的标准管道
chunk_length_s=chunk_length_s, # 每个片段的长度(秒)
stride_length_s=stride_length_s, # 片段之间的重叠时间(秒)
return_timestamps=return_timestamps # 是否返回时间戳信息
)
class ConfigurationManager:
"""
单例模式:用于管理配置信息
"""
_instance = None # 保存单例实例
def __new__(cls, *args, **kwargs):
"""
确保 ConfigurationManager 只有一个实例。
:return: 返回唯一的实例对象。
"""
if cls._instance is None: # 如果实例不存在,则创建新实例
cls._instance = super().__new__(cls)
return cls._instance # 返回已存在的实例
def __init__(self, model_path, audio_path):
"""
初始化配置管理器。
:param model_path: 模型路径。
:param audio_path: 音频文件路径。
"""
self.model_path = model_path # 模型路径
self.audio_path = audio_path # 音频文件路径
def main():
"""
主函数:解析命令行参数并执行语音识别任务
"""
# 1. 解析命令行参数
parser = argparse.ArgumentParser(description="语音识别工具") # 创建 ArgumentParser 对象
parser.add_argument(
"--model_path", # 参数名
type=str, # 参数类型
default="/home/user/whisper-large-v3-zh/", # 默认值
help="模型路径" # 帮助信息
)
parser.add_argument(
"--audio_path", # 参数名
type=str, # 参数类型
default="1.mp3", # 默认值
help="音频文件路径,默认为 '1.mp3'" # 帮助信息
)
args = parser.parse_args() # 解析命令行参数
# 2. 初始化配置管理器
config_manager = ConfigurationManager(args.model_path, args.audio_path) # 创建配置管理器实例
# 3. 创建语音识别管道
transcriber = TranscriberFactory.create_transcriber(config_manager.model_path) # 使用工厂类创建 transcriber
# 4. 执行语音识别
result = transcriber(
config_manager.audio_path, # 输入音频文件路径
batch_size=1, # 每个批次处理的输入数量,默认为 1
generate_kwargs=None, # 传递给模型 generate 方法的额外参数,默认为 None
max_new_tokens=None # 生成的最大新标记数,默认为 None
)
# 5. 输出结果
print("完整文本:") # 打印完整转录文本
print(result["text"]) # 转录结果的完整文本部分
print("\n带时间戳的片段:") # 打印带时间戳的片段
if "chunks" in result: # 如果结果中包含时间戳信息
for chunk in result["chunks"]: # 遍历每个片段
print(f"时间范围: {chunk['timestamp']} -> 文本: {chunk['text']}") # 打印时间范围和对应文本
if __name__ == "__main__":
main() # 调用主函数
流程:
1. 使用 `argparse` 解析命令行参数。
2. 创建 `ConfigurationManager` 实例,保存模型路径和音频文件路径。
3. 使用 `TranscriberFactory` 创建语音识别管道。
4. 调用 `transcriber` 执行语音识别任务。
5. 输出识别结果,包括完整文本和带时间戳的片段。
运行方式
-
使用默认参数运行:
python speech_recognition.py
-
自定义参数运行:
python speech_recognition.py --model_path /path/to/custom_model --audio_path /path/to/audio.mp3
输出格式
- 完整文本:打印完整的转录结果。
- 带时间戳的片段:如果启用了时间戳功能,打印每个片段的时间范围和对应文本。