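Overview: This article explains how a SpringBoot service can call a PyTorch speech recognition model and play audio back through Java audio libraries, covering the full workflow of model export, service integration, API design, and performance optimization.
Traditional speech recognition systems depend on toolchains such as Kaldi, which are complex to deploy and make model updates difficult. PyTorch, with its dynamic computation graph and a rich set of pretrained models (such as Wav2Vec 2.0 and Conformer), has become a popular framework for deep-learning speech recognition. SpringBoot, as an enterprise Java framework, offers RESTful interfaces and a microservice-friendly architecture that make it a natural fit for building speech processing services.
The system uses a layered architecture: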
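import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# torchscript=True makes the model return plain tuples, which torch.jit.trace requires
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h", torchscript=True)
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model.eval()  # disable dropout before tracing

# Convert to TorchScript, tracing with one second of dummy 16 kHz audio
traced_model = torch.jit.trace(model, (torch.randn(1, 16000),))
traced_model.save("wav2vec2_jit.pt")
Key points: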
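// Launch the Python script with ProcessBuilder
// Requires java.io.BufferedReader, java.io.InputStreamReader, java.nio.charset.StandardCharsets
ProcessBuilder pb = new ProcessBuilder("python", "inference.py", audioPath);
pb.redirectError(ProcessBuilder.Redirect.INHERIT); // keep stderr out of the result and avoid a full pipe
Process process = pb.start();
String result;
try (BufferedReader reader = new BufferedReader(
        new InputStreamReader(process.getInputStream(), StandardCharsets.UTF_8))) {
    result = reader.readLine(); // the script is expected to print the transcription on one line
}
int exitCode = process.waitFor(); // a non-zero exit code means inference failed
Define the proto file: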
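The Java snippet above reads a single line from the script's standard output, so the Python side has to honor that contract. The following is only a sketch of what a hypothetical `inference.py` could look like under that assumption: `load_wav_as_floats` and `main` are illustrative names, and `transcribe` is a placeholder where the real model call would go. It uses only the standard library and assumes 16-bit mono PCM input.

```python
import sys
import wave
from array import array


def load_wav_as_floats(path):
    """Read a 16-bit mono WAV file and return (samples in [-1, 1), sample_rate)."""
    with wave.open(path, "rb") as wav:
        if wav.getnchannels() != 1 or wav.getsampwidth() != 2:
            raise ValueError("expected 16-bit mono PCM")
        sample_rate = wav.getframerate()
        raw = wav.readframes(wav.getnframes())
    pcm = array("h")
    pcm.frombytes(raw)
    if sys.byteorder == "big":
        pcm.byteswap()  # WAV sample data is little-endian
    return [s / 32768.0 for s in pcm], sample_rate


def transcribe(samples, sample_rate):
    """Hypothetical placeholder: the real script would run the model here."""
    return f"<{len(samples)} samples at {sample_rate} Hz>"


def main(argv):
    if len(argv) != 2:
        print("usage: inference.py <audio.wav>", file=sys.stderr)
        return 2
    samples, sample_rate = load_wav_as_floats(argv[1])
    # Exactly one line on stdout, because the Java caller invokes readLine() once.
    print(transcribe(samples, sample_rate), flush=True)
    return 0
```

In a real script the file would end by calling `sys.exit(main(sys.argv))`, so that a failure surfaces to the Java caller as a non-zero exit code.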
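syntax = "proto3";

service SpeechService {
  rpc Recognize (AudioRequest) returns (TextResponse);
}

message AudioRequest {
  bytes audio_data = 1;
  int32 sample_rate = 2;
}

// Returned by Recognize; the Python servicer fills in the text field
message TextResponse {
  string text = 1;
}
Python-side implementation: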
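import grpc
from concurrent import futures

import numpy as np
import torch

import speech_service_pb2
import speech_service_pb2_grpc

# model and processor are the Wav2Vec2ForCTC and Wav2Vec2Processor objects,
# loaded elsewhere via from_pretrained with default options so the output exposes .logits


class SpeechServicer(speech_service_pb2_grpc.SpeechServiceServicer):
    def Recognize(self, request, context):
        # Assumption: audio_data holds 16-bit little-endian mono PCM; the processor needs floats
        samples = np.frombuffer(request.audio_data, dtype="<i2").astype(np.float32) / 32768.0
        inputs = processor(samples, sampling_rate=request.sample_rate, return_tensors="pt")
        with torch.no_grad():
            logits = model(inputs.input_values).logits
        # Greedy CTC decoding: best token per frame, then collapse into text
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.decode(predicted_ids[0])
        return speech_service_pb2.TextResponse(text=transcription)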
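The last two steps of the servicer, `torch.argmax` followed by `processor.decode`, amount to greedy CTC decoding. The sketch below illustrates that idea in plain Python under simplifying assumptions: the four-token vocabulary is a toy example, and the choice of `<pad>` as the blank token and `|` as the word separator is assumed here for illustration rather than read from the real tokenizer.

```python
def ctc_greedy_decode(logits, vocab, blank="<pad>", word_sep="|"):
    """Greedy CTC decoding over a list of per-frame score lists."""
    # 1. Pick the highest-scoring token id in every frame (the argmax step).
    ids = [max(range(len(frame)), key=frame.__getitem__) for frame in logits]
    # 2. Collapse runs of the same id into a single token.
    tokens = []
    previous = None
    for token_id in ids:
        if token_id != previous:
            tokens.append(vocab[token_id])
        previous = token_id
    # 3. Drop blanks, then turn the word separator into spaces.
    text = "".join(t for t in tokens if t != blank)
    return text.replace(word_sep, " ").strip()


def one_hot(index, size):
    """Helper for building toy logits: score 1.0 at index, 0.0 elsewhere."""
    return [1.0 if i == index else 0.0 for i in range(size)]
```

A blank between two identical tokens keeps them separate, which is how a doubled letter survives the collapse step: frames decoding to `L`, `<pad>`, `L` yield `LL`, while `L`, `L` yields a single `L`.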
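@PostMapping("/upload")
public ResponseEntity<String> uploadAudio(@RequestParam("file") MultipartFile file) {
    // Validate the audio format; getContentType() may be null, so compare from the constant
    if (!"audio/wav".equals(file.getContentType())) {
        return ResponseEntity.badRequest().body("Only WAV format is supported");
    }
    Path tempFile = null;
    try {
        // Save the upload to a temporary file
        tempFile = Files.createTempFile("audio", ".wav");
        Files.write(tempFile, file.getBytes());
        // Call the recognition service
        String result = speechRecognizer.recognize(tempFile);
        return ResponseEntity.ok(result);
    } catch (IOException e) {
        return ResponseEntity.internalServerError().build();
    } finally {
        // Remove the temporary file so uploads do not accumulate on disk
        if (tempFile != null) {
            tempFile.toFile().delete();
        }
    }
}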
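The content type checked in the upload handler is supplied by the client, so it says little about the actual bytes. One possible complement, sketched here in Python for illustration, is to inspect the RIFF/WAVE magic bytes at the start of the file. The function name `looks_like_wav` is hypothetical, and the check only confirms the container signature, not that the audio inside is decodable.

```python
import struct


def looks_like_wav(data: bytes) -> bool:
    """Return True if data starts with a RIFF container tagged as WAVE."""
    if len(data) < 12:
        return False
    # Bytes 0-3: "RIFF", bytes 4-7: chunk size (ignored here), bytes 8-11: "WAVE"
    riff, _size, wave_id = struct.unpack("<4sI4s", data[:12])
    return riff == b"RIFF" and wave_id == b"WAVE"
```

The same twelve-byte comparison is straightforward to port to the Java controller by reading the first bytes of the `MultipartFile`.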
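public void playAudio(byte[] audioData, int sampleRate)
        throws IOException, LineUnavailableException {
    // Signed 16-bit little-endian mono PCM
    AudioFormat format = new AudioFormat(sampleRate, 16, 1, true, false);
    // The stream length is given in frames; one 16-bit mono frame is 2 bytes
    AudioInputStream ais = new AudioInputStream(
            new ByteArrayInputStream(audioData), format, audioData.length / 2);
    DataLine.Info info = new DataLine.Info(Clip.class, format);
    Clip clip = (Clip) AudioSystem.getLine(info);
    // Release the line once playback stops
    clip.addLineListener(event -> {
        if (event.getType() == LineEvent.Type.STOP) {
            clip.close();
        }
    });
    clip.open(ais);
    clip.start(); // returns immediately; playback continues in the background
}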
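`playAudio` expects raw signed 16-bit little-endian mono PCM, which is what `new AudioFormat(sampleRate, 16, 1, true, false)` describes. As a way to produce test input in exactly that layout, the sketch below generates a sine tone; `sine_pcm16` and its default values are illustrative choices, not part of the original service.

```python
import math
import struct


def sine_pcm16(freq_hz, duration_s, sample_rate=16000, amplitude=0.5):
    """Return a sine tone as signed 16-bit little-endian mono PCM bytes."""
    if not 0.0 <= amplitude <= 1.0:
        raise ValueError("amplitude must be between 0 and 1")
    frame_count = round(duration_s * sample_rate)
    peak = int(amplitude * 32767)
    samples = [
        int(peak * math.sin(2 * math.pi * freq_hz * i / sample_rate))
        for i in range(frame_count)
    ]
    # "<h" is a little-endian signed 16-bit integer, one per mono frame
    return struct.pack(f"<{frame_count}h", *samples)
```

Half a second at 16 kHz gives 8000 frames and therefore 16000 bytes, which matches the `audioData.length / 2` frame count used on the Java side.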
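@GetMapping("/stream")
public ResponseEntity<StreamingResponseBody> streamAudio() {
    StreamingResponseBody responseBody = outputStream -> {
        // Simulate a real-time audio stream
        byte[] buffer = new byte[1024];
        for (int i = 0; i < 100; i++) {
            // Generate or fetch audio data
            Arrays.fill(buffer, (byte) (i % 256));
            outputStream.write(buffer);
            outputStream.flush();
            try {
                Thread.sleep(100);
            } catch (InterruptedException e) {
                // writeTo may only throw IOException, so restore the flag and stop streaming
                Thread.currentThread().interrupt();
                return;
            }
        }
    };
    HttpHeaders headers = new HttpHeaders();
    // The simulated bytes carry no WAV header; a real stream must match this content type
    headers.set(HttpHeaders.CONTENT_TYPE, "audio/wav");
    return ResponseEntity.ok().headers(headers).body(responseBody);
}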
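The streaming endpoint labels its response as `audio/wav`, but the simulated payload is raw bytes. For clients that expect a valid file, the stream would need to start with a WAV header. The following sketch builds the canonical 44-byte PCM header with `struct`; the helper name `wav_header` is hypothetical, and it assumes the total payload size is known in advance, which is true of the fixed 100 x 1024 byte simulation but not of an open-ended stream.

```python
import struct


def wav_header(sample_rate, data_bytes, channels=1, bits_per_sample=16):
    """Build a 44-byte RIFF/WAVE header for uncompressed PCM audio."""
    block_align = channels * bits_per_sample // 8  # bytes per frame
    byte_rate = sample_rate * block_align          # bytes per second
    return struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF", 36 + data_bytes, b"WAVE",          # RIFF chunk descriptor
        b"fmt ", 16, 1, channels, sample_rate,      # fmt chunk, format 1 = PCM
        byte_rate, block_align, bits_per_sample,
        b"data", data_bytes,                        # data chunk header
    )
```

Writing `wav_header(16000, 100 * 1024)` before the first buffer would make the simulated response parseable as a WAV file, assuming the payload is meant to be 16 kHz 16-bit mono.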
# Dynamic quantization: store the weights of the Linear layers as int8
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
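To make the effect of int8 quantization concrete, the sketch below works through the arithmetic on a handful of weights using a simple symmetric per-tensor scheme. This is an illustration of the general idea only and is not PyTorch's implementation, which differs in details such as how scales and zero points are chosen; `quantize_int8` and `dequantize` are illustrative names.

```python
def quantize_int8(weights):
    """Map floats to integers in [-127, 127] using one shared scale."""
    max_abs = max((abs(w) for w in weights), default=0.0)
    # The largest magnitude maps to 127; an all-zero tensor keeps scale 1.0
    scale = max_abs / 127 if max_abs > 0 else 1.0
    quantized = [max(-127, min(127, round(w / scale))) for w in weights]
    return quantized, scale


def dequantize(quantized, scale):
    """Recover approximate float values from the int8 representation."""
    return [q * scale for q in quantized]
```

Each restored weight differs from the original by at most half a quantization step (`scale / 2`), which is the precision traded for storing one byte per weight instead of four.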
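// Requires @EnableAsync on a configuration class; the method must be called through the Spring bean
@Async
public CompletableFuture<String> asyncRecognize(Path audioPath) {
    // Call the recognition service on Spring's task executor
    String result = speechRecognizer.recognize(audioPath);
    return CompletableFuture.completedFuture(result);
}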
<!-- pom.xml key dependencies -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
    <groupId>org.bytedeco</groupId>
    <artifactId>javacv-platform</artifactId>
    <version>1.5.7</version>
</dependency>
<dependency>
    <groupId>io.grpc</groupId>
    <artifactId>grpc-netty-shaded</artifactId>
    <version>1.45.1</version>
</dependency>
@RestController
@RequestMapping("/api/speech")
public class SpeechController {

    @Autowired
    private SpeechRecognizer speechRecognizer;

    @Autowired
    private AudioPlayer audioPlayer;

    // Needed by playText; TextToSpeechService is assumed to be a bean defined elsewhere
    @Autowired
    private TextToSpeechService textToSpeechService;

    @PostMapping("/recognize")
    public ResponseEntity<SpeechResult> recognize(
            @RequestParam("file") MultipartFile file,
            @RequestParam(defaultValue = "false") boolean async) {
        if (async) {
            // The result is discarded here; a real endpoint would store it for later retrieval.
            // The upload may be cleaned up once the request returns, so copy it first in real code.
            CompletableFuture.supplyAsync(() -> {
                String text = speechRecognizer.recognize(file);
                return new SpeechResult(text, LocalDateTime.now());
            });
            return ResponseEntity.accepted().build();
        }
        String text = speechRecognizer.recognize(file);
        return ResponseEntity.ok(new SpeechResult(text, LocalDateTime.now()));
    }

    @GetMapping("/play/{text}")
    public void playText(@PathVariable String text) throws Exception {
        byte[] audioData = textToSpeechService.convertToAudio(text);
        audioPlayer.play(audioData, 16000);
    }
}
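In the asynchronous branch of the controller, the request is accepted but the caller has no way to retrieve the transcription afterwards. A common pattern for closing that gap is a job registry: submission returns an identifier, and a second endpoint reports the job's state. The sketch below shows the pattern in Python with a thread pool; `RecognitionJobs` and its methods are hypothetical names, the registry lives in memory only, and finished jobs are never evicted.

```python
import threading
import uuid
from concurrent.futures import ThreadPoolExecutor


class RecognitionJobs:
    """In-memory registry mapping job ids to background recognition tasks."""

    def __init__(self, recognize, max_workers=4):
        self._recognize = recognize  # callable taking an audio path, returning text
        self._pool = ThreadPoolExecutor(max_workers=max_workers)
        self._jobs = {}
        self._lock = threading.Lock()

    def submit(self, audio_path):
        """Start recognition in the background and return a job id at once."""
        job_id = uuid.uuid4().hex
        future = self._pool.submit(self._recognize, audio_path)
        with self._lock:
            self._jobs[job_id] = future
        return job_id

    def status(self, job_id):
        """Report the job state without blocking the caller."""
        with self._lock:
            future = self._jobs.get(job_id)
        if future is None:
            return {"state": "unknown"}
        if not future.done():
            return {"state": "pending"}
        error = future.exception()
        if error is not None:
            return {"state": "failed", "error": str(error)}
        return {"state": "done", "text": future.result()}

    def shutdown(self):
        """Wait for running jobs to finish and release the worker threads."""
        self._pool.shutdown(wait=True)
```

In the Java service, the same shape could be built from a `ConcurrentHashMap` of `CompletableFuture` values, with the 202 response carrying the job id and a separate GET endpoint calling the equivalent of `status`.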
FROM openjdk:11-jre-slim
COPY target/speech-service.jar /app.jar
COPY models/ /models/
CMD ["java", "-jar", "/app.jar"]
Audio length mismatch:
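One way a length mismatch can arise is when a model or pipeline expects a fixed number of samples, for example one second at 16 kHz, while real recordings vary. A possible mitigation, sketched here as an assumption rather than as the article's prescribed fix, is to zero-pad short clips and truncate long ones; `pad_or_trim` is an illustrative helper.

```python
def pad_or_trim(samples, target_len, pad_value=0.0):
    """Return a new list of exactly target_len samples."""
    if target_len < 0:
        raise ValueError("target_len must be non-negative")
    if len(samples) >= target_len:
        # Too long: keep only the leading target_len samples
        return list(samples[:target_len])
    # Too short: append silence until the target length is reached
    return list(samples) + [pad_value] * (target_len - len(samples))
```

Truncation discards speech past the cutoff, so for long recordings splitting into several windows is usually preferable to trimming.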
Low recognition accuracy:
Insufficient real-time performance:
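Latency often improves when audio is recognized in short windows as it arrives instead of waiting for the full recording. The sketch below shows one way to split a sample buffer into fixed-size chunks with optional overlap, so that words falling on a boundary appear in two consecutive windows. The function name and parameters are assumptions for illustration; merging the per-chunk transcriptions is left out.

```python
def chunk_samples(samples, chunk_len, overlap=0):
    """Yield consecutive windows of chunk_len samples sharing overlap samples."""
    if chunk_len <= 0 or not 0 <= overlap < chunk_len:
        raise ValueError("need chunk_len > 0 and 0 <= overlap < chunk_len")
    if not samples:
        return
    step = chunk_len - overlap
    start = 0
    while True:
        yield samples[start:start + chunk_len]
        # Stop once the current window has reached the end of the buffer
        if start + chunk_len >= len(samples):
            break
        start += step
```

With 16 kHz audio, `chunk_len=16000` and `overlap=3200` would give one-second windows that share 200 ms with their neighbors; the final window may be shorter than the rest.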
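Multilingual support:
This article has walked through the full workflow from PyTorch model deployment to SpringBoot service integration, with code examples and architectural guidance. In practice, it is advisable to implement the core recognition feature first and then add playback, streaming, and other advanced features step by step. For production deployments, pay particular attention to exception handling, resource isolation, and performance monitoring.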