Spaces:
Running
Running
| import ZeroClient, { Logger } from './ZeroClient'; | |
| import * as starry from '../../src/starry'; | |
| import PyProcessor from './PyProcessor'; | |
| import { destructPromise } from './async-queue'; | |
| import { getPort } from 'portfinder'; | |
| import util from 'util'; | |
| import { Options } from 'python-shell'; | |
/**
 * Promise-returning form of portfinder's callback-style `getPort`.
 * Its result is stringified and used as the bind address of locally spawned `PyProcessor` clients.
 */
const getPortPromise = util.promisify(getPort);
/**
 * Result of the `layout` predictor — presumably one entry per input image stream.
 *
 * NOTE(review): `theta` and `interval` are not interpreted anywhere in this module; they look like
 * page rotation and line spacing respectively — confirm against the Python predictor.
 */
export interface LayoutResult {
  /** Detected page layout. */
  detection: starry.PageLayout;
  theta: number;
  interval: number;
  /** Dimensions of the source image, when the predictor reports them. */
  sourceSize?: { width: number; height: number };
}
/**
 * Argument and result shapes of every remote predictor, keyed by predictor name.
 *
 * Members are declared with synchronous result types; actual calls go through
 * `PyClients.predictScoreImages`, which resolves to the member's `ReturnType`.
 */
export interface PredictorInterface {
  /** Page layout detection (sent to the Python side as `predictDetection`). */
  layout: (streams: Buffer[]) => LayoutResult[];
  /** Layout detection given `baseLayouts` from an earlier pass; served by the `layout` client's `predictReinforce`. */
  layout$reinforce: (streams: Buffer[], baseLayouts: LayoutResult[]) => LayoutResult[];
  /** Returns image buffers (requested with `by_buffer: true`). */
  gauge: (streams: Buffer[]) => { image: Buffer }[];
  /** Returns image buffers (requested with `by_buffer: true`). */
  mask: (streams: Buffer[]) => { image: Buffer }[];
  semantic: (streams: Buffer[]) => any[];
  textLoc: (streams: Buffer[]) => any[];
  textOcr: (params: { buffers: Buffer[]; location: any[] }) => any[];
  brackets: (params: { buffers: Buffer[] }) => any[];
  topo: (params: { clusters: starry.EventCluster[] }) => any[];
  /** `params` is the tuple `[source: Buffer, gauge: Buffer, baseY: number]`. */
  gaugeRenderer: (params: [Buffer, Buffer, number]) => { buffer: Buffer; size: { width: number; height: number } };
  jianpu: (params: { buffers: Buffer[] }) => any[];
}
/**
 * Name of a predictor. For names containing `$` (e.g. `layout$reinforce`), `PyClients.predictScoreImages`
 * uses the client registered under the part before the `$`.
 */
type PredictorType = keyof PredictorInterface;

/**
 * Per-predictor client configuration. A string is an address a `ZeroClient` binds to; an object is
 * python-shell `Options` whose `scriptPath` is spawned as a local `PyProcessor`.
 */
export type PyClientsConstructOptions = Partial<Record<PredictorType, Options | string>>;
/**
 * Lazily starts one client per predictor type, caches it, and dispatches prediction
 * requests to the matching client.
 */
export class PyClients {
  /**
   * Startup promise per predictor type. An entry is stored *before* startup finishes,
   * so concurrent callers share a single startup.
   */
  clients = new Map<string, Promise<ZeroClient>>();

  /**
   * @param options per-predictor configuration: an address string, or python-shell options with `scriptPath`
   * @param logger sink for startup and timing messages; defaults to `console`
   */
  constructor(public readonly options: PyClientsConstructOptions, public readonly logger: Logger = console) {}

  /** Human-readable description of a thrown value, for logging. */
  private static describeError(err: unknown): string {
    // JSON.stringify(new Error('x')) yields "{}" (message/stack are non-enumerable), so read them explicitly.
    if (err instanceof Error) {
      return err.stack || err.message;
    }
    try {
      const json = JSON.stringify(err);
      return json === undefined ? String(err) : json;
    } catch {
      return String(err);
    }
  }

  /**
   * Returns the client for `type`, starting it on first use.
   *
   * A string option creates a `ZeroClient` bound to that address. An object option spawns a
   * `PyProcessor` for its `scriptPath` and binds it to a port obtained from portfinder.
   *
   * @throws Error (as a rejection) when `options` has no entry for `type`, or when startup fails.
   *   A failed startup is removed from the cache, so a later call retries it.
   */
  async getClient(type: PredictorType): Promise<ZeroClient> {
    const cached = this.clients.get(type);
    if (cached) {
      return cached;
    }

    const opt = this.options[type];
    if (!opt) {
      throw new Error(`no config for client \`${type}\` found`);
    }

    const [promise, resolve, reject] = destructPromise<ZeroClient>();
    // Register before any await: otherwise a concurrent call for the same type (e.g. `layout`
    // and `layout$reinforce`) would miss the cache and spawn a duplicate process.
    this.clients.set(type, promise);

    try {
      if (typeof opt === 'string') {
        const client = new ZeroClient();
        client.bind(opt);
        resolve(client);
      } else {
        const { scriptPath, ...option } = opt;
        const client = new PyProcessor(scriptPath, option, this.logger);
        await client.bind(`${await getPortPromise()}`);
        resolve(client);
      }
      this.logger.info(`PyClients: ${type} started`);
    } catch (err) {
      this.logger.error(`PyClients: ${type} start fail: ${PyClients.describeError(err)}`);
      // Do not cache the failure forever; the next getClient call starts a fresh attempt.
      if (this.clients.get(type) === promise) {
        this.clients.delete(type);
      }
      reject(err);
    }

    return promise;
  }

  /** Sends a `checkHost` request to the client for `type` (starting it if needed). */
  async checkHost(type: PredictorType): Promise<string> {
    const client = await this.getClient(type);
    return client.request('checkHost');
  }

  /** Starts every configured client in parallel; rejects if any of them fails to start. */
  async warmup(): Promise<void> {
    // Skip keys that are present but empty (e.g. `{ layout: undefined }`); getClient would reject on them.
    const types = (Object.keys(this.options) as PredictorType[]).filter((type) => !!this.options[type]);
    await Promise.all(types.map((type) => this.getClient(type)));
  }

  /**
   * Runs a model prediction on the Python side.
   *
   * The part of `type` before `$` selects the client, so `layout$reinforce` is served by the
   * `layout` client (via its `predictReinforce` method).
   *
   * @param type predictor name — any key of `PredictorInterface`
   * @param args arguments declared for `type` in `PredictorInterface`
   * @returns the predictor's response; `null` (after logging an error) if `type` has no dispatch entry
   */
  async predictScoreImages<T extends PredictorType>(type: T, ...args: Parameters<PredictorInterface[T]>): Promise<ReturnType<PredictorInterface[T]>> {
    const clientType = type.split('$')[0] as PredictorType;
    const client = await this.getClient(clientType);
    let res = null;
    this.logger.info(`[predictor]: ${type} py start..`);
    const start = Date.now();
    switch (type) {
      case 'layout':
        res = await client.request('predictDetection', args);
        break;
      case 'layout$reinforce':
        res = await client.request('predictReinforce', args);
        break;
      case 'gauge':
      case 'mask':
        res = await client.request('predict', args, { by_buffer: true });
        break;
      case 'semantic':
      case 'textLoc':
        res = await client.request('predict', args);
        break;
      // These predictors take a single params argument, which is spread rather than sent as an array.
      case 'textOcr':
      case 'brackets':
      case 'topo':
      case 'gaugeRenderer':
      case 'jianpu':
        res = await client.request('predict', ...args);
        break;
      default:
        this.logger.error(`[predictor]: no predictor ${type}`);
    }
    this.logger.info(`[predictor]: ${type} py duration: ${Date.now() - start}ms`);
    return res;
  }
}