import ZeroClient, { Logger } from './ZeroClient';
import * as starry from '../../src/starry';
import PyProcessor from './PyProcessor';
import { destructPromise } from './async-queue';
import { getPort } from 'portfinder';
import util from 'util';
import { Options } from 'python-shell';

const getPortPromise = util.promisify(getPort);

/** Result of the `layout` predictor for one page image. */
export interface LayoutResult {
  detection: starry.PageLayout;
  theta: number;
  interval: number;
  sourceSize?: {
    width: number;
    height: number;
  };
}

/**
 * Call signatures of every predictor a `PyClients` instance can dispatch to.
 * The key (before any `$` suffix) selects the configured client; see `predictScoreImages`.
 */
export interface PredictorInterface {
  layout: (streams: Buffer[]) => LayoutResult[];
  layout$reinforce: (streams: Buffer[], baseLayouts: LayoutResult[]) => LayoutResult[];
  gauge: (streams: Buffer[]) => { image: Buffer }[];
  mask: (streams: Buffer[]) => { image: Buffer }[];
  semantic: (streams: Buffer[]) => any[];
  textLoc: (streams: Buffer[]) => any[];
  textOcr: (params: { buffers: Buffer[]; location: any[] }) => any[];
  brackets: (params: { buffers: Buffer[] }) => any[];
  topo: (params: { clusters: starry.EventCluster[] }) => any[];
  // params: [source: Buffer, gauge: Buffer, baseY: number]
  gaugeRenderer: (params: [Buffer, Buffer, number]) => { buffer: Buffer; size: { width: number; height: number } };
  jianpu: (params: { buffers: Buffer[] }) => any[];
}

type PredictorType = keyof PredictorInterface;

/**
 * Per-predictor client configuration.
 * A string is an address bound by a `ZeroClient`; an object spawns a local `PyProcessor`
 * running `scriptPath` with the remaining python-shell options.
 */
export type PyClientsConstructOptions = Partial<Record<PredictorType, string | (Options & { scriptPath: string })>>;

export class PyClients {
  /** Lazily created clients, keyed by predictor type. Values are (possibly pending) start promises. */
  clients = new Map<PredictorType, Promise<ZeroClient | PyProcessor>>();

  constructor(public readonly options: PyClientsConstructOptions, public readonly logger: Logger = console) {}

  /**
   * Return the client for `type`, starting it on first use.
   *
   * The pending start promise is cached before any await, so concurrent callers share one client
   * instead of each spawning their own. If starting fails, the cache entry is removed so a later
   * call can retry.
   *
   * @throws Error (as a rejection) when no configuration exists for `type`, or the client failed to start.
   */
  async getClient(type: PredictorType): Promise<ZeroClient | PyProcessor> {
    const cached = this.clients.get(type);
    if (cached) {
      return cached;
    }

    const opt = this.options[type];
    if (!opt) {
      throw new Error(`no config for client \`${type}\` found`);
    }

    const [promise, resolve, reject] = destructPromise<ZeroClient | PyProcessor>();
    // Register before awaiting anything: otherwise concurrent callers would each create a client.
    this.clients.set(type, promise);

    try {
      if (typeof opt === 'string') {
        const client = new ZeroClient();
        client.bind(opt);
        resolve(client);
      } else {
        const { scriptPath, ...option } = opt;
        const client = new PyProcessor(scriptPath, option, this.logger);
        await client.bind(`${await getPortPromise()}`);
        resolve(client);
      }

      this.logger.info(`PyClients: ${type} started`);
    } catch (err) {
      // Do not cache the failure permanently; allow a later getClient() to retry.
      this.clients.delete(type);
      // JSON.stringify(new Error(...)) yields "{}" (message/stack are non-enumerable), so log them explicitly.
      const detail = err instanceof Error ? err.stack ?? err.message : JSON.stringify(err);
      this.logger.error(`PyClients: ${type} start fail: ${detail}`);
      reject(err);
    }

    return promise;
  }

  /** Forward a `checkHost` request to the client configured for `type`. */
  async checkHost(type: PredictorType): Promise<boolean> {
    const client = await this.getClient(type);

    return client.request('checkHost');
  }

  /** Start every configured client in parallel; rejects if any of them fails to start. */
  async warmup() {
    const opts = Object.keys(this.options) as PredictorType[];
    await Promise.all(opts.map((type) => this.getClient(type)));
  }

  /**
   * Run model prediction.
   *
   * The client is chosen by the part of `type` before `$` (e.g. `layout$reinforce` uses the `layout` client).
   * Each type maps to a request method and argument convention:
   * - `layout` → `predictDetection`, `layout$reinforce` → `predictReinforce` (args passed as one array);
   * - `gauge`, `mask` → `predict` with `{ by_buffer: true }`;
   * - `semantic`, `textLoc` → `predict` (args passed as one array);
   * - `textOcr`, `brackets`, `topo`, `gaugeRenderer`, `jianpu` → `predict` (args spread).
   *
   * Unknown types are logged as errors and resolve to `null`.
   *
   * @param type predictor key from `PredictorInterface`
   * @param args arguments matching that predictor's signature
   */
  async predictScoreImages<T extends PredictorType>(
    type: T,
    ...args: Parameters<PredictorInterface[T]>
  ): Promise<ReturnType<PredictorInterface[T]>> {
    const clientType = type.split('$')[0] as PredictorType;
    const client = await this.getClient(clientType);
    let res = null;

    this.logger.info(`[predictor]: ${type} py start..`);
    const start = Date.now();

    switch (type) {
      case 'layout':
        res = await client.request('predictDetection', args);
        break;
      case 'layout$reinforce':
        res = await client.request('predictReinforce', args);
        break;
      case 'gauge':
      case 'mask':
        res = await client.request('predict', args, { by_buffer: true });
        break;
      case 'semantic':
      case 'textLoc':
        res = await client.request('predict', args);
        break;
      case 'textOcr':
      case 'brackets':
      case 'topo':
      case 'gaugeRenderer':
      case 'jianpu':
        res = await client.request('predict', ...args);
        break;
      default:
        this.logger.error(`[predictor]: no predictor ${type}`);
    }

    this.logger.info(`[predictor]: ${type} py duration: ${Date.now() - start}ms`);

    return res;
  }
}