// starry/backend/libs/predictors.ts
// feat: add Python ML services (CPU mode) with model download
// (commit 2b7aae2, author: k-l-lambda)
import ZeroClient, { Logger } from './ZeroClient';
import * as starry from '../../src/starry';
import PyProcessor from './PyProcessor';
import { destructPromise } from './async-queue';
import { getPort } from 'portfinder';
import util from 'util';
import { Options } from 'python-shell';
// Promisified portfinder.getPort — resolves to a free local TCP port number.
const getPortPromise = util.promisify(getPort);
/**
 * Result of a page-layout prediction for one score image.
 */
export interface LayoutResult {
	/** Detected page layout structure. */
	detection: starry.PageLayout;
	// NOTE(review): presumably the page rotation angle — confirm with the layout service.
	theta: number;
	// NOTE(review): presumably the staff-line spacing interval — confirm with the layout service.
	interval: number;
	/** Dimensions of the source image, when known. */
	sourceSize?: {
		width: number;
		height: number;
	};
}
/**
 * Signatures of the prediction operations exposed by the Python services.
 * Serves as the key/parameter map for PyClients.predictScoreImages; the
 * actual calls are asynchronous even though these signatures are written
 * synchronously.
 */
export interface PredictorInterface {
	layout: (streams: Buffer[]) => LayoutResult[];
	/** Refines previously predicted layouts ('$' suffix reuses the 'layout' client). */
	layout$reinforce: (streams: Buffer[], baseLayouts: LayoutResult[]) => LayoutResult[];
	gauge: (streams: Buffer[]) => {
		image: Buffer;
	}[];
	mask: (streams: Buffer[]) => {
		image: Buffer;
	}[];
	semantic: (streams: Buffer[]) => any[];
	textLoc: (streams: Buffer[]) => any[];
	textOcr: (params: { buffers: Buffer[]; location: any[] }) => any[];
	brackets: (params: { buffers: Buffer[] }) => any[];
	topo: (params: { clusters: starry.EventCluster[] }) => any[];
	// params: [source: Buffer, gauge: Buffer, baseY: number]
	gaugeRenderer: (params: [Buffer, Buffer, number]) => { buffer: Buffer; size: { width: number; height: number } };
	jianpu: (params: { buffers: Buffer[] }) => any[];
}
/** Name of a predictor operation, e.g. 'layout' or 'gauge'. */
type PredictorType = keyof PredictorInterface;
/**
 * Per-predictor configuration: either an address string for an already-running
 * service, or python-shell Options used to spawn a local Python process.
 */
export type PyClientsConstructOptions = Partial<Record<PredictorType, Options | string>>;
/**
 * Manages a pool of Python predictor clients, one per predictor type,
 * lazily starting each backing connection/process on first use.
 */
export class PyClients {
	/** Cache of in-flight or ready clients, keyed by predictor type. */
	clients = new Map<string, Promise<ZeroClient>>();

	constructor(public readonly options: PyClientsConstructOptions, public readonly logger: Logger = console) {}

	/**
	 * Returns the client for `type`, creating it on the first call.
	 *
	 * A string option is treated as an address for a plain ZeroClient to bind;
	 * an object option spawns a PyProcessor bound to a free local port.
	 *
	 * @throws Error when no configuration exists for `type`.
	 */
	async getClient(type: PredictorType): Promise<ZeroClient> {
		const cached = this.clients.get(type);
		if (cached) {
			return cached;
		}

		const opt = this.options[type];
		if (!opt) {
			throw new Error(`no config for client \`${type}\` found`);
		}

		const [promise, resolve, reject] = destructPromise<ZeroClient>();
		// Cache the promise BEFORE any await, so concurrent callers for the same
		// type share one client instead of spawning duplicate processes.
		this.clients.set(type, promise);

		try {
			if (typeof opt === 'string') {
				const client = new ZeroClient();
				client.bind(opt);
				resolve(client);
			} else {
				const { scriptPath, ...option } = opt;
				const client = new PyProcessor(scriptPath, option, this.logger);
				await client.bind(`${await getPortPromise()}`);
				resolve(client);
			}
			this.logger.info(`PyClients: ${type} started`);
		} catch (err) {
			// Drop the failed entry so a later call can retry startup;
			// current awaiters still observe the rejection.
			this.clients.delete(type);
			this.logger.error(`PyClients: ${type} start fail: ${JSON.stringify(err)}`);
			reject(err);
		}

		return promise;
	}

	/** Pings the predictor service for `type`; resolves to its reported host string. */
	async checkHost(type: PredictorType): Promise<string> {
		const client = await this.getClient(type);
		return client.request('checkHost');
	}

	/** Eagerly starts every configured client in parallel. */
	async warmup(): Promise<void> {
		const types = Object.keys(this.options) as PredictorType[];
		await Promise.all(types.map((type) => this.getClient(type)));
	}

	/**
	 * Model prediction: dispatches `args` to the predictor service for `type`.
	 *
	 * @param type one of layout | layout$reinforce | mask | gauge | semantic |
	 *             textLoc | textOcr | brackets | topo | gaugeRenderer | jianpu
	 * @param args arguments matching the PredictorInterface signature for `type`
	 * @returns the service response, or null when `type` is unrecognized
	 */
	async predictScoreImages<T extends PredictorType>(type: T, ...args: Parameters<PredictorInterface[T]>): Promise<ReturnType<PredictorInterface[T]>> {
		// A '$' suffix selects a request variant served by the same base client
		// (e.g. 'layout$reinforce' reuses the 'layout' client).
		const clientType = type.split('$')[0] as PredictorType;
		const client = await this.getClient(clientType);

		let res = null;
		this.logger.info(`[predictor]: ${type} py start..`);
		const start = Date.now();

		switch (type) {
			case 'layout':
				res = await client.request('predictDetection', args);
				break;
			case 'layout$reinforce':
				res = await client.request('predictReinforce', args);
				break;
			case 'gauge':
			case 'mask':
				// Image-producing predictors transfer results as raw buffers.
				res = await client.request('predict', args, { by_buffer: true });
				break;
			case 'semantic':
			case 'textLoc':
				res = await client.request('predict', args);
				break;
			case 'textOcr':
			case 'brackets':
			case 'topo':
			case 'gaugeRenderer':
			case 'jianpu':
				// These take a single params object/tuple; spread it out of the rest tuple.
				res = await client.request('predict', ...args);
				break;
			default:
				this.logger.error(`[predictor]: no predictor ${type}`);
		}

		this.logger.info(`[predictor]: ${type} py duration: ${Date.now() - start}ms`);
		return res;
	}
}