# -----------------------code---------------------------
## conda create -n pyannote python=3.8
## conda activate pyannote
## conda install pytorch torchaudio -c pytorch
## pip install https://github.com/pyannote/pyannote-audio/archive/develop.zip
## pip install uvicorn -i https://pypi.tuna.tsinghua.edu.cn/simple
## pip install fastapi -i https://pypi.tuna.tsinghua.edu.cn/simple
## pip install pydub
## pip install python-multipart
## conda install -c pytorch faiss-cpu   (see https://github.com/facebookresearch/faiss/blob/main/INSTALL.md for other install options)
## pip install cx_Oracle -i https://pypi.tuna.tsinghua.edu.cn/simple
## pip install numpy==1.22 -i https://pypi.tuna.tsinghua.edu.cn/simple
import uvicorn
from fastapi import FastAPI, File, UploadFile
import os
from pyannote.audio import Pipeline
from pyannote.audio import Inference
from fastapi.middleware.cors import CORSMiddleware
import numpy as np
import faiss
import time
import cx_Oracle #cx_Oracle module is imported to provide the API for accessing the Oracle database
import decimal
import json
# FastAPI application object.  The title/description/version are the metadata
# FastAPI publishes for the interactive API docs.
app = FastAPI(
    title='calculate audio feature',
    description='we will split audio, get every speeker feature and speek time',
    version='1.0.0',
)

# CORS: accept requests from any origin, with any method and any header;
# credentialed cross-origin requests are not enabled (allow_credentials=False).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Hyper-parameters applied to the pretrained speaker-segmentation pipeline.
HYPER_PARAMETERS = {
    # onset/offset activation thresholds
    "onset": 0.5,
    "offset": 0.5,
    # remove speech regions shorter than that many seconds.
    "min_duration_on": 1.0,
    # fill non-speech regions shorter than that many seconds.
    "min_duration_off": 2.0,
}

# Segmentation pipeline used by /audio/split to find speaker turns.
pipeline = Pipeline.from_pretrained("pyannote/speaker-segmentation")
pipeline.instantiate(HYPER_PARAMETERS)

# Embedding model evaluated on the whole input (window="whole") on the CPU;
# used both for whole files and for cropped turns.
inference = Inference("pyannote/embedding", window="whole", device="cpu")
# Length of every stored/queried feature vector; presumably equal to the
# output size of the "pyannote/embedding" model -- confirm if the model changes.
vec_dim = 512

# Highest database ID already loaded into each index.  The loaders only fetch
# rows whose ID is greater than these values.
last_id = 0    # audio file index (faiss_index)
last_id2 = 0   # user voice index (faiss_index2)

# Exact L2 indexes, each wrapped in an IndexIDMap so vectors are stored under
# the database IDs supplied to add_with_ids().
index = faiss.IndexFlatL2(vec_dim)
faiss_index = faiss.IndexIDMap(index)      # audio file index
index2 = faiss.IndexFlatL2(vec_dim)
faiss_index2 = faiss.IndexIDMap(index2)    # user voice index
def OutConverter(value):
    """Map a database NULL (``None``) to ``''``; pass every other value through unchanged."""
    return '' if value is None else value
def OutputTypeHandler(cursor, name, defaultType, size, precision, scale):
    """cx_Oracle output type handler installed on the connection.

    Chooses the Python-side variable type for each fetched column:

    * VARCHAR / CHAR -> ``str`` (keeping the column ``size``)
    * CLOB           -> ``cx_Oracle.LONG_STRING``
    * BLOB           -> ``cx_Oracle.LONG_BINARY``
    * NUMBER         -> ``decimal.Decimal``
    * TIMESTAMP      -> ``str``

    Every variable it creates uses ``OutConverter``, so NULL values come back as
    ``''``.  For any other column type it returns ``None``, leaving the type to
    cx_Oracle's default handling.  ``name``, ``precision`` and ``scale`` are
    unused but are part of the handler signature cx_Oracle calls.
    """
    batch = cursor.arraysize
    if defaultType in (cx_Oracle.DB_TYPE_VARCHAR, cx_Oracle.DB_TYPE_CHAR):
        # Character columns are the only ones that pass an explicit size.
        return cursor.var(str, size, arraysize=batch, outconverter=OutConverter)

    # Ordered == comparisons (not a dict lookup): cx_Oracle type constants
    # compare equal to a family of database types, which hashing would not model.
    if defaultType == cx_Oracle.CLOB:
        target = cx_Oracle.LONG_STRING
    elif defaultType == cx_Oracle.BLOB:
        target = cx_Oracle.LONG_BINARY
    elif defaultType == cx_Oracle.NUMBER:
        target = decimal.Decimal
    elif defaultType == cx_Oracle.TIMESTAMP:
        target = str
    else:
        return None
    return cursor.var(target, arraysize=batch, outconverter=OutConverter)
# Oracle connection shared by the whole service through the module-level
# `cursor` (used by oracle_query()).
# Each setting can be overridden with an environment variable; the fallbacks
# are the values this script has always used, so existing deployments keep
# working unchanged.  The APP_ prefix avoids clashing with Oracle's own
# environment variables (e.g. ORACLE_SID).
# NOTE(review): a real database password is committed in source below -- set
# APP_ORACLE_PASSWORD in the environment, remove the fallback, and rotate it.
# Raw string so the backslashes in this Windows path are never read as escapes.
ORACLE_CLIENT_LIB_DIR = os.environ.get("APP_ORACLE_CLIENT_LIB_DIR", r"D:\lixl\instantclient_21_6")
ORACLE_HOST = os.environ.get("APP_ORACLE_HOST", "hzzzlc.tpddns.cn")
ORACLE_PORT = os.environ.get("APP_ORACLE_PORT", "21521")
ORACLE_SID = os.environ.get("APP_ORACLE_SID", "XE")
ORACLE_USER = os.environ.get("APP_ORACLE_USER", "smhd")
ORACLE_PASSWORD = os.environ.get("APP_ORACLE_PASSWORD", "Zzlc_1qaz789")

cx_Oracle.init_oracle_client(lib_dir=ORACLE_CLIENT_LIB_DIR)
dsn = cx_Oracle.makedsn(ORACLE_HOST, ORACLE_PORT, ORACLE_SID)
connection = cx_Oracle.connect(ORACLE_USER, ORACLE_PASSWORD, dsn)
# Convert fetched columns (NULL -> '', CLOB -> str, ...) via the handler above.
connection.outputtypehandler = OutputTypeHandler
cursor = connection.cursor()
# cx_Oracle fetch-tuning attributes: rows fetched/prefetched per round trip.
cursor.arraysize = 1000
cursor.prefetchrows = 1000
@app.post("/audio/split")
async def speaker_segmentation(file: UploadFile = File(...)):
    """Split an uploaded recording into speaker turns and embed each turn.

    On success returns ``{'code': 1, 'message': 'success', 'data': elements}``,
    where each element is ``[start, end, embedding]`` for one turn yielded by
    ``diarization.itertracks`` and ``embedding`` is ``inference.crop(...)``
    converted with ``tolist()``.  On any exception returns
    ``{'code': 0, 'message': 'failed:[<error>] '}``.  The temporary copy of the
    upload is deleted whenever it was created.
    """
    # NOTE(review): pipeline()/inference.crop() are synchronous calls inside an
    # async endpoint, so they run on the event-loop thread while they compute.
    path = None  # stays None if writeFile() fails, so there is nothing to delete
    try:
        path = await writeFile(file)
        diarization = pipeline(path)
        print(diarization)
        elements = []
        for turn, _track, _speaker in diarization.itertracks(yield_label=True):
            start, end = turn.start, turn.end
            embedding = inference.crop(path, turn)
            elements.append([start, end, embedding.tolist()])
        responseData = {'code': 1, 'message': 'success', 'data': elements}
    except Exception as e:
        print(str(e))
        message = 'failed:[' + str(e) + '] '
        responseData = {'code': 0, 'message': message}
    finally:
        # Without this guard a failing writeFile() leaves `path` unbound and
        # os.remove raises, replacing the error response above.
        if path is not None:
            os.remove(path)
    return responseData
@app.post("/audio/feature/calc")
async def get_audio_embedding(file: UploadFile = File(...)):
    """Compute a single embedding for a whole uploaded audio file.

    Returns ``{'code': 1, 'message': 'success', 'data': embedding}`` where
    ``embedding`` is ``inference(path).tolist()``, or
    ``{'code': 0, 'message': 'failed:[<error>] '}`` on any exception.  The
    temporary copy of the upload is deleted whenever it was created.
    """
    path = None  # stays None if writeFile() fails, so there is nothing to delete
    try:
        path = await writeFile(file)
        embedding = inference(path).tolist()
        responseData = {'code': 1, 'message': 'success', 'data': embedding}
    except Exception as e:
        message = 'failed:[' + str(e) + '] '
        responseData = {'code': 0, 'message': message}
    finally:
        # Guard: an unbound `path` here would raise and mask the error response.
        if path is not None:
            os.remove(path)
    return responseData
@app.post("/audio/feature/search_by_feature")
async def search_by_feature(feature):
    """Search the audio-file index with a JSON-encoded feature vector.

    ``feature`` is JSON text decoding to one vector.  Returns the 10 nearest
    entries of ``faiss_index`` as scored by ``score()`` under ``'data'`` with
    code 1, or ``{'code': 0, 'message': 'failed:[<error>] '}`` if decoding,
    searching or scoring raises.
    """
    try:
        # One query row -> shape (1, dim), as faiss expects.
        query = np.array([json.loads(feature)]).astype("float32")
        distances, labels = faiss_index.search(query, 10)
        return {'code': 1, 'message': 'success', 'data': score(distances, labels)}
    except Exception as err:
        return {'code': 0, 'message': 'failed:[' + str(err) + '] '}
@app.post("/audio/feature/search_by_audio")
async def search_by_audio(file: UploadFile = File(...)):
    """Embed an uploaded audio file and search the audio-file index with it.

    Returns the 10 nearest entries of ``faiss_index`` as scored by ``score()``
    under ``'data'`` with code 1, or
    ``{'code': 0, 'message': 'failed:[<error>] '}`` on any exception.  The
    temporary copy of the upload is deleted whenever it was created.
    """
    path = None  # stays None if writeFile() fails, so there is nothing to delete
    try:
        path = await writeFile(file)
        embedding = inference(path)
        # One query row -> shape (1, dim), as faiss expects.
        query_vectors = np.array([embedding]).astype("float32")
        topk = 10
        res_distance, res_index = faiss_index.search(query_vectors, topk)
        responseData = {'code': 1, 'message': 'success', 'data': score(res_distance, res_index)}
    except Exception as e:
        message = 'failed:[' + str(e) + '] '
        responseData = {'code': 0, 'message': message}
    finally:
        # Guard: an unbound `path` here would raise and mask the error response.
        if path is not None:
            os.remove(path)
    return responseData
@app.post("/voice/feature/search_by_feature")
async def search_by_feature2(feature):
    """Search the user-voice index with a JSON-encoded feature vector.

    ``feature`` is JSON text decoding to one vector.  Returns the 3 nearest
    entries of ``faiss_index2`` as scored by ``score()`` under ``'data'`` with
    code 1, or ``{'code': 0, 'message': 'failed:[<error>] '}`` if decoding,
    searching or scoring raises.
    """
    try:
        # One query row -> shape (1, dim), as faiss expects.
        query = np.array([json.loads(feature)]).astype("float32")
        distances, labels = faiss_index2.search(query, 3)
        return {'code': 1, 'message': 'success', 'data': score(distances, labels)}
    except Exception as err:
        return {'code': 0, 'message': 'failed:[' + str(err) + '] '}
@app.post("/voice/feature/search_by_audio")
async def search_by_audio2(file: UploadFile = File(...)):
    """Embed an uploaded audio file and search the user-voice index with it.

    Returns the 3 nearest entries of ``faiss_index2`` as scored by ``score()``
    under ``'data'`` with code 1, or
    ``{'code': 0, 'message': 'failed:[<error>] '}`` on any exception.  The
    temporary copy of the upload is deleted whenever it was created.
    """
    path = None  # stays None if writeFile() fails, so there is nothing to delete
    try:
        path = await writeFile(file)
        embedding = inference(path)
        # One query row -> shape (1, dim), as faiss expects.
        query_vectors = np.array([embedding]).astype("float32")
        topk = 3
        res_distance, res_index = faiss_index2.search(query_vectors, topk)
        responseData = {'code': 1, 'message': 'success', 'data': score(res_distance, res_index)}
    except Exception as e:
        message = 'failed:[' + str(e) + '] '
        responseData = {'code': 0, 'message': message}
    finally:
        # Guard: an unbound `path` here would raise and mask the error response.
        if path is not None:
            os.remove(path)
    return responseData
@app.post("/audio/feature/load")
async def audioLoad():
    """Load new rows from the audio-feature table into ``faiss_index``.

    Calls ``build()``, which fetches rows with an ID greater than ``last_id``.
    Returns ``{'code': 1, 'message': 'success, last id:<last_id>'}`` on success,
    or ``{'code': 0, 'message': 'failed:[<error>] '}`` on failure -- the same
    1 = success / 0 = failure convention as every other endpoint (previously
    success also reported code 0, so callers could not tell the two apart).
    """
    # NOTE(review): build() runs a blocking database query inside this async endpoint.
    try:
        build()
        # build() rebinds the module-level last_id; a plain read needs no `global`.
        responseData = {'code': 1, 'message': "success, last id:" + str(last_id)}
    except Exception as e:
        message = 'failed:[' + str(e) + '] '
        responseData = {'code': 0, 'message': message}
    return responseData
@app.post("/voice/feature/load")
async def voiceLoad():
    """Load new rows from the voice-feature table into ``faiss_index2``.

    Calls ``build2()``, which fetches rows with an ID greater than ``last_id2``.
    Returns ``{'code': 1, 'message': 'success, last id:<last_id2>'}`` on success,
    or ``{'code': 0, 'message': 'failed:[<error>] '}`` on failure -- the same
    1 = success / 0 = failure convention as every other endpoint (previously
    success also reported code 0, so callers could not tell the two apart).
    """
    # NOTE(review): build2() runs a blocking database query inside this async endpoint.
    try:
        build2()
        # build2() rebinds the module-level last_id2; a plain read needs no `global`.
        responseData = {'code': 1, 'message': "success, last id:" + str(last_id2)}
    except Exception as e:
        message = 'failed:[' + str(e) + '] '
        responseData = {'code': 0, 'message': message}
    return responseData
async def writeFile(file):
    """Save an uploaded file into the current working directory.

    ``file`` must provide a ``filename`` attribute (may be ``None``) and an
    awaitable ``read()`` returning the content as bytes.  The file is written
    as ``<cwd>/<time.time()>-<random hex>-<name>`` and its absolute path is
    returned; the caller is responsible for deleting it.

    Only the last component of the client-supplied filename is used, so names
    such as ``../../x`` or ``..\\..\\x`` cannot escape the working directory.
    A missing or empty name becomes ``upload``.
    """
    import uuid  # local import: only this helper needs unique names

    # Normalize Windows separators first so basename() strips them on any OS.
    raw_name = (file.filename or "upload").replace("\\", "/")
    safe_name = os.path.basename(raw_name) or "upload"
    # The random part keeps two concurrent uploads with the same name (and the
    # same coarse timestamp) from overwriting each other.
    unique_name = str(time.time()) + '-' + uuid.uuid4().hex + '-' + safe_name
    path = os.path.join(os.path.abspath(os.curdir), unique_name)
    contents = await file.read()
    with open(path, "wb") as f:
        f.write(contents)
    return path
def build():
    """Add rows of ZZLC$AUDIO_FEATURE with ``ID > last_id`` to ``faiss_index``.

    Each row's FEATURE column is parsed with ``json.loads`` (presumably a JSON
    list of ``vec_dim`` numbers -- confirm against the writer of that table)
    and added to the index under its integer ID.  ``last_id`` is advanced to
    the highest loaded ID only after the whole batch has been added, so a
    failure part-way through (bad JSON, NULL feature, faiss error) leaves
    ``last_id`` unchanged and the same rows are retried on the next call
    instead of being silently skipped.  Does nothing when no new rows exist.
    """
    global last_id
    # last_id is an int this module tracks itself, so concatenating it is safe.
    sql_statement = 'SELECt ID, FEATURE from ZZLC$AUDIO_FEATURE WHERe ID > ' + str(last_id) + ' ORDER BY ID ASC'
    rows = oracle_query(sql_statement)
    if not rows:
        return
    ids = []
    features = []
    for row in rows:
        ids.append(int(row['ID']))
        features.append(json.loads(row['FEATURE']))
    vectors = np.array(features).astype("float32")
    id_array = np.array(ids).astype('int64')
    faiss_index.add_with_ids(vectors, id_array)
    # Rows are ordered by ID ASC, so the last ID is the new high-water mark.
    last_id = ids[-1]
def build2():
    """Add rows of ZZLC$FEATURES_VOICE with ``ID > last_id2`` to ``faiss_index2``.

    Each row's FEATURE column is parsed with ``json.loads`` (presumably a JSON
    list of ``vec_dim`` numbers -- confirm against the writer of that table)
    and added to the index under its integer ID.  ``last_id2`` is advanced to
    the highest loaded ID only after the whole batch has been added, so a
    failure part-way through (bad JSON, NULL feature, faiss error) leaves
    ``last_id2`` unchanged and the same rows are retried on the next call
    instead of being silently skipped.  Does nothing when no new rows exist.
    """
    global last_id2
    # last_id2 is an int this module tracks itself, so concatenating it is safe.
    sql_statement = 'SELECt ID, FEATURE from ZZLC$FEATURES_VOICE WHERe ID > ' + str(last_id2) + ' ORDER BY ID ASC'
    rows = oracle_query(sql_statement)
    if not rows:
        return
    ids = []
    features = []
    for row in rows:
        ids.append(int(row['ID']))
        features.append(json.loads(row['FEATURE']))
    vectors = np.array(features).astype("float32")
    id_array = np.array(ids).astype('int64')
    faiss_index2.add_with_ids(vectors, id_array)
    # Rows are ordered by ID ASC, so the last ID is the new high-water mark.
    last_id2 = ids[-1]
def oracle_query(sql_statement):
    """Execute ``sql_statement`` on the shared module-level ``cursor``.

    Returns every result row as a dict mapping column name (as reported in
    ``cursor.description``) to value.
    """
    cursor.execute(sql_statement)
    column_names = [description[0] for description in cursor.description]

    def as_dict(*values):
        return dict(zip(column_names, values))

    # Installed after execute(), once the column names of this statement are known.
    cursor.rowfactory = as_dict
    return cursor.fetchall()
def score(res_distance, res_index):
    """Turn the first query row of a faiss search result into scored hits.

    Args:
        res_distance: distances array returned by ``Index.search``; only row 0 is used.
        res_index: labels (IDs) array returned by ``Index.search``; only row 0 is used.

    Returns:
        A list of ``{"id": <int>, "score": <str>}`` in search order, where
        score is ``25 * (4 - distance / 1000000)`` formatted with two decimals:
        distance 0 maps to "100.00", 4,000,000 to "0.00", and larger distances
        go negative.  (The 1e6 scale presumably matches the typical squared L2
        distance between embeddings -- TODO confirm.)

    faiss pads results with label -1 (and a huge distance) when the index holds
    fewer than k vectors; those placeholder slots are not real hits and are
    skipped instead of being returned as id -1 with a nonsensical score.
    """
    hits = []
    for label, distance in zip(res_index[0].tolist(), res_distance[0].tolist()):
        if label == -1:
            continue
        hits.append({"id": label, "score": format(25 * (4 - distance / 1000000), '.2f')})
    return hits
if __name__ == '__main__':
    # Populate both in-memory faiss indexes from Oracle before serving, then
    # start a single-worker uvicorn server on all interfaces, port 6060.
    for load_index in (build, build2):
        load_index()
    uvicorn.run(app=app, host="0.0.0.0", port=6060, workers=1)
# -----------------------------------------------------------------------
# PS (接口测试方式 / how to test the API): open the interactive API docs at http://localhost:6060/docs



