Spaces:
Running
Running
| import sha1 from 'sha1'; | |
| import { Canvas, Image, loadImage } from 'skia-canvas'; | |
| import { WeakLRUCache } from 'weak-lru-cache'; | |
| import * as starry from '../../src/starry'; | |
| import { SemanticGraph } from '../../src/starry'; | |
| import { LayoutResult, PyClients } from './predictors'; | |
| import { constructSystem, convertImage } from './util'; | |
// Node.js polyfills: expose the skia-canvas `Canvas` / `Image` classes under the
// browser global names and provide a Buffer-based `btoa`, each only when the
// runtime does not already define one.
{
  const globalScope = globalThis as any;
  globalScope.OffscreenCanvas = globalScope.OffscreenCanvas || Canvas;
  globalScope.Image = globalScope.Image || Image;
  globalScope.btoa = globalScope.btoa || ((str: string) => Buffer.from(str, 'binary').toString('base64'));
}
/** Extra horizontal offset, in viewport pixels, added to the left of every staff snapshot. */
const STAFF_PADDING_LEFT = 32;

/** Maximum page width (not referenced by the code visible in this file). */
const MAX_PAGE_WIDTH = 1200;

/**
 * Viewport used for staff snapshots sent to the gauge stage.
 * `viewportHeight` is the snapshot height in pixels; `viewportUnit` is the
 * number of pixels per logical unit.
 */
const GAUGE_VISION_SPEC = { viewportHeight: 256, viewportUnit: 8 };

/** Viewport used when positioning mask images (same fields as the gauge spec). */
const MASK_VISION_SPEC = { viewportHeight: 192, viewportUnit: 8 };

/** Viewport used for staff snapshots sent to the semantic stage. */
const SEMANTIC_VISION_SPEC = { viewportHeight: 192, viewportUnit: 8 };
/** Timing figures reported by a single `predictPages` run. */
interface OMRStat {
  /** Duration of the whole run, in milliseconds. */
  cost: number;
  /** Duration of the per-page processing loop, in milliseconds. */
  pagesCost: number;
  /** Number of pages that were actually processed (not omitted). */
  pages: number;
}

/** Aggregate over several `OMRStat` records. */
interface OMRSummary {
  /** Sum of all run durations, in milliseconds. */
  costTotal: number;
  /** Average cost per processed page, in milliseconds. */
  costPerPage: number;
  /** Total number of processed pages. */
  pagesTotal: number;
  /** Number of stat records (scores) that were summarised. */
  scoreN: number;
}
/**
 * Normalise an image for layout detection by drawing it, scaled to the given
 * width, onto a fresh canvas while keeping its aspect ratio.
 * @param image source image
 * @param width target width in pixels
 * @returns a canvas of `width` x (scaled height) containing the resized image
 */
function scaleForLayout(image: Image, width: number): Canvas {
  // Computed once so the canvas size and the drawn image size are guaranteed to agree.
  const height = (image.height / image.width) * width;
  const canvas = new Canvas(width, height);
  canvas.getContext('2d').drawImage(image, 0, 0, width, height);
  return canvas;
}
/**
 * Choose a global unit size and page size for the score from the layout
 * detections of all pages.
 *
 * For each usable page the width is expressed in staff intervals (`vw`) and the
 * height/width ratio (`hwr`) is recorded. The page with the largest `vw`
 * determines `score.unitSize`; the largest `hwr` determines the page height.
 * Pages without any area carrying staff lines, or whose staff interval does not
 * yield a finite positive `vw`, are ignored.
 *
 * @param score score whose `unitSize` and `pageSize` are written
 * @param detections layout result per page (entries may be empty)
 * @param outputWidth width of the output page
 * @throws Error('empty result') when no page provides a usable staff interval
 */
function setGlobalPageSize(score: starry.Score, detections: LayoutResult[], outputWidth: number) {
  const sizeRatios = detections
    .filter((d) => d && d.detection && d.detection.areas?.length)
    .map((d) => {
      const intervals = d.detection.areas.filter((area) => area.staves?.middleRhos?.length).map((area) => area.staves.interval);
      // Math.min() of an empty list is Infinity, which makes `vw` 0; such pages are dropped below.
      const staffInterval = Math.min(...intervals);
      const { width, height } = d.sourceSize;

      return {
        vw: width / staffInterval, // page width in logical units
        hwr: height / width, // page height/width ratio
      };
    })
    .filter((ratio) => Number.isFinite(ratio.vw) && ratio.vw > 0);

  if (!sizeRatios.length) {
    throw new Error('empty result');
  }

  const maxVW = Math.max(...sizeRatios.map((ratio) => ratio.vw));
  const maxAspect = Math.max(...sizeRatios.map((ratio) => ratio.hwr));

  score.unitSize = outputWidth / maxVW;

  // displayed page size
  score.pageSize = {
    width: outputWidth,
    height: outputWidth * maxAspect,
  };
}
/** Start a single asynchronous job and hand back its promise. */
const batchTask = (fn: () => Promise<any>) => {
  return fn();
};

/** Start all jobs at once; resolves with their results in input order. */
const concurrencyTask = (fns: (() => Promise<any>)[]) => {
  const running = fns.map((job) => job());
  return Promise.all(running);
};
/**
 * Render a snapshot of one staff from its system's background image.
 *
 * The snapshot is `spec.viewportHeight` pixels tall (times `scaling`), filled
 * with white, and the system image is drawn so that the staff's `staffY` line
 * sits at the vertical middle of the viewport.
 *
 * @param system system owning the staff; must carry a `backgroundImage`
 * @param staffIndex index into `system.staves`
 * @param options `paddingLeft` (viewport pixels added on the left), `scaling`
 *                (output magnification) and the viewport `spec`
 * @returns the rendered canvas, or `null` when the system, its background image
 *          or the staff is missing
 */
const shootStaffImage = async (
  system: starry.System,
  staffIndex: number,
  { paddingLeft = 0, scaling = 1, spec }: { paddingLeft?: number; scaling?: number; spec: { viewportHeight: number; viewportUnit: number } }
): Promise<Canvas> => {
  if (!system?.backgroundImage) return null;

  const staff = system.staves[staffIndex];
  if (!staff) return null;

  const unit = spec.viewportUnit;
  const placement = system.imagePosition;

  // half the viewport height, in logical units
  const middleUnits = spec.viewportHeight / unit / 2;

  const width = placement.width * unit;
  const height = placement.height * unit;
  const x = placement.x * unit + paddingLeft;
  const y = (placement.y - (staff.top + staff.staffY - middleUnits)) * unit;

  const canvas = new Canvas(Math.round(width + x) * scaling, spec.viewportHeight * scaling);
  const context = canvas.getContext('2d');

  // white backdrop for the parts the system image does not cover
  context.fillStyle = 'white';
  context.fillRect(0, 0, canvas.width, canvas.height);

  const background = await loadImage(system.backgroundImage);
  context.drawImage(background, x * scaling, y * scaling, width * scaling, height * scaling);

  return canvas;
};
/**
 * Cut one image per detected system out of a page and build the page's systems.
 *
 * The page canvas is first redrawn through the page's source matrix to obtain a
 * corrected canvas; each layout area is then copied out of it and turned into a
 * system via `constructSystem`. Also sets `page.width` / `page.height` from the
 * score's page size and unit size.
 *
 * @param page page whose `layout.areas` drive the cut-outs; `page.systems` is filled
 * @param score score providing `pageSize` and `unitSize`
 * @param pageCanvas canvas holding the original page image
 * @returns the corrected page canvas, or `null` when the page has no layout areas
 */
async function shootImageByDetection({
  page,
  score,
  pageCanvas,
}: {
  score: starry.Score;
  page: starry.Page;
  pageCanvas: Canvas; // canvas with the original image already drawn
}) {
  const areas = page?.layout?.areas;
  if (!areas?.length) {
    return null;
  }

  page.width = score.pageSize.width / score.unitSize;
  page.height = score.pageSize.height / score.unitSize;

  const correctCanvas = new Canvas(pageCanvas.width, pageCanvas.height);
  const ctx = correctCanvas.getContext('2d');
  const { width, height } = correctCanvas;
  const [a, b, c, d] = page.source.matrix;

  // translation terms paired with the matrix, derived from the canvas half-size
  const offsetX = (-1 / 2) * width + (1 / 2) * a * width + (1 / 2) * b * height;
  const offsetY = (-1 / 2) * height + (1 / 2) * c * width + (1 / 2) * d * height;

  ctx.save();
  ctx.setTransform(a, b, c, d, offsetX, offsetY);
  ctx.drawImage(pageCanvas, 0, 0);
  ctx.restore();

  const interval = page.source.interval;

  // centre of the source canvas, in staff-interval units
  const sourceCenter = {
    x: pageCanvas.width / 2 / interval,
    y: pageCanvas.height / 2 / interval,
  };

  areas.forEach((area, systemIndex) => {
    console.assert(area.staves?.middleRhos?.length, '[shootImageByDetection] empty area:', area);

    // copy the area's pixels into a canvas of its own
    const pixels = ctx.getImageData(area.x, area.y, area.width, area.height);
    const canvas = new Canvas(area.width, area.height);
    canvas.getContext('2d').putImageData(pixels, 0, 0);

    const position = {
      x: (area.x + area.staves.phi1) / interval - sourceCenter.x + page.width / 2,
      y: area.y / interval - sourceCenter.y + page.height / 2,
    };

    page.systems[systemIndex] = constructSystem({
      page,
      backgroundImage: canvas.toBufferSync('png'),
      detection: area.staves,
      imageSize: { width: area.width, height: area.height },
      position,
    });
  });

  return correctCanvas;
}
/**
 * Snapshot a staff with the semantic viewport and store the PNG on the staff as
 * its background image, together with the matching image position.
 * Used when no gauge correction is applied to the staff.
 */
async function shootStaffBackgroundImage({ system, staff, staffIndex }: { system: starry.System; staff: starry.Staff; staffIndex: number }) {
  const { viewportHeight, viewportUnit } = SEMANTIC_VISION_SPEC;

  const snapshot = await shootStaffImage(system, staffIndex, {
    paddingLeft: STAFF_PADDING_LEFT,
    spec: SEMANTIC_VISION_SPEC,
  });

  staff.backgroundImage = snapshot.toBufferSync('png');

  // position of the snapshot relative to the staff, in logical units
  staff.imagePosition = {
    x: -STAFF_PADDING_LEFT / viewportUnit,
    y: staff.staffY - viewportHeight / 2 / viewportUnit,
    width: snapshot.width / viewportUnit,
    height: snapshot.height / viewportUnit,
  };
}
/**
 * Deformation correction for a single staff.
 *
 * Takes a 2x snapshot of the staff with the gauge viewport, sends it with the
 * gauge image and a base Y value to the `gaugeRenderer` predictor, and stores
 * the returned image on the staff as its background. Any existing mask image on
 * the staff is cleared.
 *
 * @param system system owning the staff
 * @param staff staff to update
 * @param staffIndex index of the staff inside the system
 * @param gaugeImage gauge image previously predicted for this staff
 * @param pyClients predictor clients
 */
async function gaugeStaff({
  system,
  staff,
  staffIndex,
  gaugeImage,
  pyClients,
}: {
  system: starry.System;
  staff: starry.Staff;
  staffIndex: number;
  gaugeImage: Buffer;
  pyClients: PyClients;
}) {
  const { viewportHeight, viewportUnit } = GAUGE_VISION_SPEC;

  const snapshot = await shootStaffImage(system, staffIndex, {
    paddingLeft: STAFF_PADDING_LEFT,
    spec: GAUGE_VISION_SPEC,
    scaling: 2,
  });
  const sourceBuffer = snapshot.toBufferSync('png');

  const baseY = (system.middleY - (staff.top + staff.staffY)) * viewportUnit + viewportHeight / 2;

  const rendered = await pyClients.predictScoreImages('gaugeRenderer', [sourceBuffer, gaugeImage, baseY]);

  staff.backgroundImage = rendered.buffer;

  // the renderer reports the output size in pixels; convert to logical units
  const outputSize = rendered.size;
  staff.imagePosition = {
    x: -STAFF_PADDING_LEFT / viewportUnit,
    y: staff.staffY - outputSize.height / 2 / viewportUnit,
    width: outputSize.width / viewportUnit,
    height: outputSize.height / viewportUnit,
  };

  staff.maskImage = null;
}
/**
 * Noise reduction for a single staff: attach the predicted mask image to the
 * staff and set the image position from the mask's pixel size.
 *
 * @param staff staff to update
 * @param staffIndex index of the staff (not used by this function)
 * @param maskImage mask image predicted for this staff
 */
async function maskStaff({ staff, staffIndex, maskImage }: { staff: starry.Staff; staffIndex: number; maskImage: Buffer }) {
  const { viewportHeight, viewportUnit } = MASK_VISION_SPEC;

  // decoded only to learn the mask's dimensions
  const decoded = await loadImage(maskImage);

  staff.maskImage = maskImage;
  staff.imagePosition = {
    x: -STAFF_PADDING_LEFT / viewportUnit,
    y: staff.staffY - viewportHeight / 2 / viewportUnit,
    width: decoded.width / viewportUnit,
    height: decoded.height / viewportUnit,
  };
}
/**
 * Semantic recognition for a single staff: shift the predicted graph left by
 * the snapshot padding, assign it to the system and the staff, drop previously
 * predicted tokens and re-assemble the system.
 *
 * @param score score used to assemble the system
 * @param staffIndex index of the staff inside the system
 * @param system system owning the staff
 * @param staff staff receiving the semantics
 * @param graph semantic graph predicted for the staff (offset in place)
 */
async function semanticStaff({
  score,
  staffIndex,
  system,
  staff,
  graph,
}: {
  score: starry.Score;
  staffIndex: number;
  system: starry.System;
  staff: starry.Staff;
  graph: SemanticGraph;
}) {
  // undo the left padding that was added when the staff snapshot was taken
  const shiftX = -STAFF_PADDING_LEFT / SEMANTIC_VISION_SPEC.viewportUnit;
  graph.offset(shiftX, 0);

  system.assignSemantics(staffIndex, graph);
  staff.assignSemantics(graph);
  staff.clearPredictedTokens();

  // a missing (or falsy) threshold falls back to 1
  const threshold = score.settings?.semanticConfidenceThreshold || 1;
  score.assembleSystem(system, threshold);
}
/**
 * Replace every image reference held by a page through a mapping callback.
 *
 * Visits, in order: the page source `url`, then for each system its
 * `backgroundImage` followed by each staff's `backgroundImage` and `maskImage`.
 * Each field is overwritten with whatever the callback returns.
 *
 * @param page page to update in place
 * @param onReplaceImageKey maps the current field value to its replacement
 */
function replacePageImages(page: starry.Page, onReplaceImageKey: (src: string) => any) {
  // collect every (owner, field) pair first, then rewrite them in one pass
  const slots: [any, string][] = [[page.source, 'url']];

  for (const system of page.systems) {
    slots.push([system, 'backgroundImage']);

    for (const staff of system.staves) {
      slots.push([staff, 'backgroundImage']);
      slots.push([staff, 'maskImage']);
    }
  }

  for (const [owner, field] of slots) {
    owner[field] = onReplaceImageKey(owner[field]);
  }
}
/** Counter of a single pipeline stage. */
export type TaskProgress = {
  /** Number of work items expected for the stage. */
  total?: number;
  /** Number of work items completed so far. */
  finished?: number;
};

/** One input page for `predictPages`. */
export interface OMRPage {
  /** Image location or raw image data. */
  url: string | Buffer;
  /** Optional identifier; used instead of `url` for the cache key and the source name. */
  key?: string;
  /** Layout supplied by the caller; when absent the layout is predicted. */
  layout?: LayoutResult;
  /** Ignore a cached result for this page. */
  renew?: boolean;
  /** Per-page gauge switch; `false` disables the gauge stage for this page. */
  enableGauge?: boolean;
}

/** Progress counters keyed by pipeline stage. */
export interface ProgressState {
  layout?: TaskProgress;
  text?: TaskProgress;
  gauge?: TaskProgress;
  mask?: TaskProgress;
  semantic?: TaskProgress;
  regulate?: TaskProgress;
  brackets?: TaskProgress;
}
/**
 * Tracks per-stage progress counters and notifies an optional listener
 * whenever a counter advances.
 */
class OMRProgress {
  /** Current counters, keyed by stage. */
  state: ProgressState = {};
  /** Listener invoked with the full state after every `increase`; may be absent. */
  onChange?: (evt: ProgressState) => void;

  /**
   * @param onChange optional listener; when omitted, progress is tracked silently
   */
  constructor(onChange?: (evt: ProgressState) => void) {
    this.onChange = onChange;
  }

  /**
   * Register the expected number of work items for a stage.
   * An already registered total is kept; a stage that only has a `finished`
   * counter so far receives the total.
   */
  setTotal(stage: keyof ProgressState, total: number) {
    const info = this.state[stage];
    if (!info) {
      this.state[stage] = { total, finished: 0 };
    } else if (info.total === undefined) {
      info.total = total;
    }
  }

  /**
   * Advance a stage's `finished` counter and notify the listener.
   * @param stage stage to advance
   * @param step number of completed items (default 1)
   */
  increase(stage: keyof ProgressState, step = 1) {
    // Store the counter in the state so increments on an unregistered stage are not lost.
    const info: TaskProgress = this.state[stage] || (this.state[stage] = { finished: 0 });
    info.finished = (info.finished ?? 0) + step;

    // The listener is optional (callers may not supply a progress callback).
    this.onChange?.(this.state);
  }
}
/** An image given either as a string (URL, data URL, path, ...) or as raw bytes. */
type SourceImage = string | Buffer;

/** Options accepted by `predictPages`. */
export interface OMROption {
  /** Width of the output page; `predictPages` falls back to 1200. */
  outputWidth?: number;
  /** Score title. */
  title?: string;
  /** Storage for serialised pages, keyed by a page hash. */
  pageStore?: {
    // NOTE(review): `Boolean` is the boxed wrapper type; `boolean` is the idiomatic
    // choice, but the declaration is kept as-is for compatibility with existing stores.
    has?: (key: string) => Promise<Boolean>;
    get: (key: string) => Promise<string>;
    set: (key: string, val: string) => Promise<void>;
  };
  /** Ignore cached results for all pages. */
  renew?: boolean;
  /** Pipeline stages to run; all default stages run when empty or absent. */
  processes?: (keyof ProgressState)[];
  /** Progress listener. */
  onProgress?: (progress: ProgressState) => void;
  /** Replaces every image reference, e.g. to upload it or convert its format. */
  onReplaceImage?: (src: SourceImage) => Promise<string>;
}
/** Cache instance backing the default page store. */
const lruCache = new WeakLRUCache();

/** Default page store: keeps serialised pages in the in-process `lruCache`. */
const pageStore = {
  /** Look up the serialised page stored under `key`. */
  get: async (key: string) => {
    const cached = lruCache.getValue(key);
    return cached as string;
  },

  /** Store the serialised page under `key`. */
  set: async (key: string, val: string) => {
    lruCache.setValue(key, val);
  },
};
/**
 * Default image replacement: convert the image through `convertImage` and
 * return it as a `data:image/webp;base64,...` string.
 *
 * Only buffers, http(s) URLs and `data:image/` URLs are converted; any other
 * string is returned unchanged.
 *
 * @param src image to replace
 */
const onReplaceImage = async (src: SourceImage) => {
  const isBinary = src instanceof Buffer;
  const isRemoteOrInline = typeof src === 'string' && (/^https?:\/\//.test(src) || /^data:image\//.test(src));

  if (!isBinary && !isRemoteOrInline) {
    return src;
  }

  const converted = await convertImage(src);
  const webpBuffer = converted.buffer;

  return `data:image/webp;base64,${webpBuffer.toString('base64')}`;
};
/**
 * Recognise every page image and assemble the results into one score.
 *
 * Steps:
 *  1. load the source images and rescale them for layout detection;
 *  2. obtain a layout per page (predicted, reinforced from a caller-supplied
 *     layout when gauge is not enabled for the page, or taken as supplied);
 *  3. derive the global page size from all layouts;
 *  4. per page, either reuse the cached page JSON or run the brackets, text,
 *     gauge, mask and semantic stages selected by `option.processes`;
 *  5. pass every image through `option.onReplaceImage` and write the page JSON
 *     to `option.pageStore`.
 *
 * Note: `option` and the `layout` of each entry in `images` are modified in place.
 *
 * @param pyClients predictor clients (also provides the logger)
 * @param images pages to recognise
 * @param option pipeline options
 * @returns the assembled score, the indices of pages that were skipped, and timing figures
 * @throws Error when layout detection yields no result for a page, or when no
 *         page provides a usable layout ('empty result')
 */
export const predictPages = async (
  pyClients: PyClients,
  images: OMRPage[],
  option: OMROption = { outputWidth: 1200, pageStore, onReplaceImage }
): Promise<{ score: starry.Score; omitPages: number[]; stat: OMRStat }> => {
  const logger = pyClients.logger;

  option.outputWidth = option.outputWidth || 1200;
  option.pageStore = option.pageStore || pageStore;
  option.onReplaceImage = option.onReplaceImage || onReplaceImage;
  option.processes =
    Array.isArray(option.processes) && option.processes.length > 0 ? option.processes : ['layout', 'text', 'gauge', 'mask', 'semantic', 'brackets'];

  // `onProgress` is optional: fall back to a no-op so progress updates never throw.
  const progress: OMRProgress = new OMRProgress(option.onProgress || (() => {}));

  const t0 = Date.now();

  // Pre-process caller-supplied layouts: drop areas without staff lines, and drop layouts without a detection.
  images.forEach((image) => {
    if (image.layout?.detection) {
      image.layout.detection.areas = image.layout.detection?.areas?.filter((a) => a?.staves?.middleRhos?.length > 0);
    } else {
      delete image.layout;
    }
  });

  const score = new starry.Score({
    title: option?.title,
    stavesCount: 2,
    paperOptions: {
      raggedLast: true,
      raggedLastBottom: true,
    },
    headers: {},
    instrumentDict: {},
    settings: {
      enabledGauge: option.processes.includes('gauge'),
      semanticConfidenceThreshold: 1,
    },
  });

  logger.info(`[predictor]: download_source_images-${images.length}`);

  // original photographed images
  const originalImages: Image[] = await Promise.all(images.map((img) => loadImage(img.url as any)));

  logger.info(`[predictor]: source_images_downloaded-${images.length}`);

  /******************************* layout detection start *************************/
  // images fed to layout detection, scaled to the width of the supplied layout when there is one
  const pageCanvasList: Canvas[] = originalImages.map((img, index) => scaleForLayout(img, images[index]!.layout?.sourceSize?.width ?? img.width));

  progress.setTotal('layout', originalImages.length);
  progress.setTotal('text', originalImages.length);

  const detections = await Promise.all(
    pageCanvasList.map(async (cvs, key) => {
      if (!images[key].layout) return (await pyClients.predictScoreImages('layout', [cvs.toBufferSync('png')]))?.[0];

      // reinforce layout from front-end if no gauge
      if (!images[key].enableGauge && images[key]?.layout?.detection?.areas?.length)
        return (await pyClients.predictScoreImages('layout$reinforce', [cvs.toBufferSync('png')], [images[key].layout]))?.[0];

      return images[key].layout;
    })
  );

  detections.forEach((page, index) => {
    // Everything below dereferences `detection`; fail with a clear message instead of an opaque TypeError.
    if (!page?.detection) {
      throw new Error(`[predictor]: layout detection returned no result for page ${index}`);
    }

    page.detection.areas = page.detection.areas?.filter((a) => a?.staves?.middleRhos?.length > 0);
  });

  // original image value -> replacement produced by `option.onReplaceImage`
  const imageURLMap = new Map<SourceImage, string>();
  const collectImage = async (source: SourceImage): Promise<void> => {
    const url = await option.onReplaceImage(source);
    imageURLMap.set(source, url);
  };

  // Image fields of pages restored from the cache are already-replaced strings that were never
  // collected in this run; keep those instead of overwriting them with `undefined`.
  const resolveImage = (src: SourceImage) => (typeof src === 'string' && !imageURLMap.has(src) ? src : imageURLMap.get(src));

  // decide the global display size from the aspect ratios of all pages
  setGlobalPageSize(score, detections, option.outputWidth);

  // Build (or restore from cache) the page at `pageIndex` and cut out its system images.
  async function createPage(detect, pageIndex: number) {
    const { url, key, layout, enableGauge } = images[pageIndex];

    const pageKey = sha1(JSON.stringify({ key: key || url, layout, enableGauge }));

    const cachedPageJson = await option.pageStore.get(pageKey);

    // A page is skipped when a cached result may be reused, or when it has no areas to process.
    const omit = !option.renew && ((cachedPageJson && !images[pageIndex].renew) || !detect.detection.areas?.length);

    const page = (score.pages[pageIndex] =
      omit && cachedPageJson
        ? starry.recoverJSON<starry.Page>(cachedPageJson, starry)
        : new starry.Page({
            source: {
              name: key || (typeof url === 'string' && /https?:\/\//.test(url) ? url : null),
              size: 0,
              url,
              crop: {
                unit: '%',
                x: 0,
                y: 0,
                width: 100,
                height: 100,
              },
              dimensions: detect.sourceSize,
              matrix: [Math.cos(detect.theta), -Math.sin(detect.theta), Math.sin(detect.theta), Math.cos(detect.theta), 0, 0],
              interval: detect.interval,
              needGauge: images[pageIndex].enableGauge,
            },
            layout: detect.detection,
          }));

    const correctCanvas = omit
      ? null
      : await shootImageByDetection({
          score,
          page,
          pageCanvas: pageCanvasList[pageIndex],
        });

    progress.increase('layout');

    return {
      page,
      omit,
      hash: pageKey,
      correctCanvas,
    };
  }

  const systemsCount = detections.reduce((acc, x) => acc + (x.detection.areas?.length ?? 0), 0);
  const stavesCount = detections.reduce((acc, x) => acc + (x.detection.areas?.reduce?.((a, y) => a + (y.staves?.middleRhos?.length ?? 0), 0) ?? 0), 0);

  progress.setTotal('gauge', stavesCount);
  progress.setTotal('mask', stavesCount);
  progress.setTotal('semantic', stavesCount);
  progress.setTotal('brackets', systemsCount);

  const allTasks = [];
  const omitPages = [];

  const t1 = Date.now();

  let n_page = 0;

  for (const pageIndex of detections.keys()) {
    const pageTasks = [];

    const { page, correctCanvas, omit, hash } = await createPage(detections[pageIndex], pageIndex);

    pageTasks.push(collectImage(page.source.url));
    pageTasks.push(...page.systems.map((system) => collectImage(system.backgroundImage)));

    logger.info(`[predictor]: check_cache_pageIndex-${pageIndex} omit: ${omit}`);

    if (omit) {
      omitPages.push(pageIndex);
    } else {
      const staves = page.systems
        .map((system, systemIndex) => system.staves.map((staff, staffIndex) => ({ pageIndex, systemIndex, staffIndex, page, system, staff })))
        .flat(1);

      await concurrencyTask([
        /******************************* brackets detection start *************************/
        async () => {
          if (!option.processes.includes('brackets')) return;

          const detection = page.layout;
          const interval = page.source.interval;

          const startTime = Date.now();

          // pixels per staff interval in the images sent to the predictor
          const OUTPUT_INTERVAL = 8;

          const bracketImages = page.systems.map((system, systemIndex) => {
            const {
              x,
              y,
              staves: { middleRhos, phi1 },
            } = detection.areas[systemIndex];

            const topMid = middleRhos[0];
            const bottomMid = middleRhos[middleRhos.length - 1];

            // 8 intervals wide, centred on the system's left edge, with 4 intervals of margin above and below
            const sourceRect = {
              x: x + phi1 - 4 * interval,
              y: y + topMid - 4 * interval,
              width: 8 * interval,
              height: bottomMid - topMid + 8 * interval,
            };

            const canvas = new Canvas(OUTPUT_INTERVAL * 8, (sourceRect.height / interval) * OUTPUT_INTERVAL);
            const context = canvas.getContext('2d');
            context.drawImage(correctCanvas, sourceRect.x, sourceRect.y, sourceRect.width, sourceRect.height, 0, 0, canvas.width, canvas.height);

            return {
              system,
              buffer: canvas.toBufferSync('png'),
            };
          });

          logger.info(`[predictor]: brackets js [pageIndex-${pageIndex}] duration: ${Date.now() - startTime}`);

          const bracketsRes = await pyClients.predictScoreImages('brackets', { buffers: bracketImages.map((x) => x.buffer) });
          progress.increase('brackets', bracketImages.length);

          bracketImages.forEach(({ system }, index) => {
            if (bracketsRes[index]) {
              system.bracketsAppearance = bracketsRes[index];
            }
          });
        },
        /******************************* brackets detection end *************************/

        /******************************* text recognition start *************************/
        async () => {
          if (!option.processes.includes('text')) return;

          try {
            const startTime = Date.now();

            const bufferForText = correctCanvas.toBufferSync('png');

            const resultLoc = await pyClients.predictScoreImages('textLoc', [bufferForText]);

            const location = resultLoc[0].filter((box) => box.score > 0);

            if (location.length > 0) {
              const [resultOCR] = await pyClients.predictScoreImages('textOcr', {
                buffers: [bufferForText],
                location,
              });

              page.assignTexts(resultOCR.areas, resultOCR.imageSize);
              page.assemble();
            }

            logger.info(`[predictor]: text js [pageIndex-${pageIndex}] duration: ${Date.now() - startTime}`);

            progress.increase('text');

            // Without an explicit title, use the largest 'Title' text token of the first page.
            if (!option.title) {
              const coverTexts: {
                confidence: number;
                fontSize: number;
                id: string;
                text: string;
                textType: 'Title' | 'Author';
                type: starry.TokenType;
                width_: number;
                x: number;
                y: number;
              }[] = score.pages[0].tokens as any;

              if (Array.isArray(coverTexts) && coverTexts.length > 0) {
                const [titleToken] = coverTexts
                  .filter((x) => x.type === starry.TokenType.Text && x.textType === 'Title')
                  .sort((a, b) => b.fontSize - a.fontSize);

                if (titleToken) {
                  score.title = titleToken.text;
                }
              }
            }
          } catch (err) {
            // JSON.stringify(error) yields "{}" for Error objects; log the stack/message instead.
            const detail = err instanceof Error ? err.stack ?? err.message : JSON.stringify(err);
            logger.error(`[predictor]: text js [pageIndex-${pageIndex}] ${detail}`);
          }
        },
        /******************************* text recognition end *************************/

        async () => {
          /******************************* deformation correction start *************************/
          await batchTask(async () => {
            const disableGauge = !option.processes.includes('gauge') || images[pageIndex].enableGauge === false;

            if (!disableGauge) {
              const gaugeRes = await pyClients.predictScoreImages(
                'gauge',
                await Promise.all(
                  staves.map(async ({ staffIndex, system }) => {
                    const startTime = Date.now();
                    const sourceCanvas = await shootStaffImage(system, staffIndex, {
                      paddingLeft: STAFF_PADDING_LEFT,
                      spec: GAUGE_VISION_SPEC,
                    });

                    logger.info(`[predictor]: gauge js shoot [page-${pageIndex}, staff-${staffIndex}] duration: ${Date.now() - startTime}`);

                    return sourceCanvas.toBufferSync('png');
                  })
                )
              );

              for (const [index, { system, staff, pageIndex, staffIndex }] of staves.entries()) {
                const startTime = Date.now();

                logger.info(`[predictor]: gauge js [page-${pageIndex}, staff-${staffIndex}] start..`);

                await gaugeStaff({
                  pyClients,
                  system,
                  staff,
                  staffIndex,
                  gaugeImage: gaugeRes[index].image,
                });

                logger.info(`[predictor]: gauge js [page-${pageIndex}, staff-${staffIndex}] duration: ${Date.now() - startTime}`);

                progress.increase('gauge');

                pageTasks.push(collectImage(staff.backgroundImage));
              }
            } else {
              // no gauge: use a plain snapshot of each staff as its background image
              for (const { system, staff, staffIndex } of staves) {
                await shootStaffBackgroundImage({
                  system,
                  staff,
                  staffIndex,
                });

                pageTasks.push(collectImage(staff.backgroundImage));
              }
            }
          });
          /******************************* deformation correction end *************************/

          await concurrencyTask([
            /******************************* noise reduction start *************************/
            async () => {
              if (!option.processes.includes('mask')) return;

              const maskRes = await pyClients.predictScoreImages(
                'mask',
                staves.map(({ staff }) => staff.backgroundImage as Buffer)
              );

              for (const [index, { staff, staffIndex }] of staves.entries()) {
                const startTime = Date.now();

                await maskStaff({
                  staff,
                  staffIndex,
                  maskImage: maskRes[index].image,
                });

                logger.info(`[predictor]: mask js [page-${pageIndex}, ${index}, staff-${staffIndex}] duration: ${Date.now() - startTime}`);

                progress.increase('mask');

                pageTasks.push(collectImage(staff.maskImage));
              }
            },
            /******************************* noise reduction end *************************/

            /******************************* semantic recognition start *************************/
            async () => {
              if (!option.processes.includes('semantic')) return;

              const semanticRes = starry.recoverJSON<starry.SemanticGraph[]>(
                await pyClients.predictScoreImages(
                  'semantic',
                  staves.map(({ staff }) => staff.backgroundImage as Buffer)
                ),
                starry
              );

              staves.forEach(({ system }) => system.clearTokens());

              for (const [index, { staffIndex, system, staff }] of staves.entries()) {
                const startTime = Date.now();

                await semanticStaff({
                  score,
                  system,
                  staff,
                  staffIndex,
                  graph: semanticRes[index],
                });

                logger.info(
                  `[predictor]: semantic js [page-${pageIndex}, system-${system.index}, staff-${staff.index}] duration: ${
                    Date.now() - startTime
                  }`
                );

                progress.increase('semantic');
              }
            },
            /******************************* semantic recognition end *************************/
          ]);
        },
      ]);

      ++n_page;
    }

    // Once all images of the page have been collected, swap them in and persist the page.
    allTasks.push(
      Promise.all(pageTasks).then(() => {
        replacePageImages(page, resolveImage);

        logger.info(`[predictor]: pageStore set: [${pageIndex}]`);

        return option.pageStore.set(hash, JSON.stringify(page));
      })
    );
  }

  const t2 = Date.now();

  await Promise.all(allTasks);

  logger.info(`[predictor]: inferenceStaffLayout: ${score.title}, [${score.systems.length}]`);

  score.inferenceStaffLayout();

  logger.info(`[predictor]: done: ${score.title}`);

  // correct semantic ids
  score.assemble();

  const t3 = Date.now();

  return {
    score,
    omitPages,
    stat: {
      cost: t3 - t0,
      pagesCost: t2 - t1,
      pages: n_page,
    },
  };
};
/**
 * Summarise the timing figures of several `predictPages` runs.
 *
 * `costPerPage` is the time spent in the per-page processing loops divided by
 * the number of processed pages, so one-off work outside that loop and skipped
 * pages do not distort it. It is `null` when no page was processed.
 *
 * @param stats one record per run
 * @returns totals over all runs plus the average per-page cost
 */
export const abstractOMRStats = (stats: OMRStat[]): OMRSummary => {
  let costTotal = 0;
  let pagesCostTotal = 0;
  let pagesTotal = 0;

  for (const stat of stats) {
    costTotal += stat.cost;
    pagesCostTotal += stat.pagesCost;
    pagesTotal += stat.pages;
  }

  return {
    costTotal,
    // Based on the page-loop cost; the accumulated value was previously computed but never used.
    costPerPage: pagesTotal ? pagesCostTotal / pagesTotal : null,
    pagesTotal,
    scoreN: stats.length,
  };
};