import sha1 from 'sha1';
import { Canvas, Image, loadImage } from 'skia-canvas';
import { WeakLRUCache } from 'weak-lru-cache';

import * as starry from '../../src/starry';
import { SemanticGraph } from '../../src/starry';
import { LayoutResult, PyClients } from './predictors';
import { constructSystem, convertImage } from './util';

// Install browser-like globals, keeping any implementation that already exists.
globalThis.OffscreenCanvas = (globalThis as any).OffscreenCanvas || Canvas;
(globalThis as any).Image = (globalThis as any).Image || Image;
globalThis.btoa = globalThis.btoa || ((str: string) => Buffer.from(str, 'binary').toString('base64'));

/** Left padding (in pixels) added in front of every staff snapshot. */
const STAFF_PADDING_LEFT = 32;

const MAX_PAGE_WIDTH = 1200;

/** Viewport used when shooting staff images for the gauge predictor. */
const GAUGE_VISION_SPEC = {
	viewportHeight: 256,
	viewportUnit: 8,
};

/** Viewport used to position mask images. */
const MASK_VISION_SPEC = {
	viewportHeight: 192,
	viewportUnit: 8,
};

/** Viewport used when shooting staff images for the semantic predictor. */
const SEMANTIC_VISION_SPEC = {
	viewportHeight: 192,
	viewportUnit: 8,
};

interface OMRStat {
	// in milliseconds
	cost: number;
	// in milliseconds
	pagesCost: number;
	pages: number;
}

interface OMRSummary {
	// in milliseconds
	costTotal: number;
	// in milliseconds
	costPerPage: number;
	pagesTotal: number;
	scoreN: number;
}

/**
 * Normalize an image for layout detection by drawing it onto a canvas of the
 * requested width, preserving the aspect ratio.
 * @param image source image
 * @param width target canvas width
 * @returns a canvas holding the scaled image
 */
function scaleForLayout(image: Image, width: number): Canvas {
	const height = (image.height / image.width) * width;

	const canvas = new Canvas(width, height);
	const ctx = canvas.getContext('2d');
	ctx.drawImage(image, 0, 0, width, height);

	return canvas;
}

/**
 * Set a suitable global page size on the score from the detection results of all images.
 * Writes `score.unitSize` and `score.pageSize`.
 * @param score score to update
 * @param detections layout detection results, one per image
 * @param outputWidth page display width
 * @throws Error('empty result') when no detection contains any area
 */
function setGlobalPageSize(score: starry.Score, detections: LayoutResult[], outputWidth: number) {
	const sizeRatios = detections
		.filter((s) => s && s.detection && s.detection.areas?.length)
		.map((v, k) => {
			const staffInterval = Math.min(...v.detection.areas.filter((area) => area.staves?.middleRhos?.length).map((x) => x.staves.interval));
			const sourceSize = v.sourceSize;

			return {
				...v,
				index: k,
				// page width (logical units)
				vw: sourceSize.width / staffInterval,
				// page height / width ratio
				hwr: sourceSize.height / sourceSize.width,
			};
		});

	if (!sizeRatios.length) {
		throw new Error('empty result');
	}

	// Widest page in logical units, tallest aspect ratio.
	const maxVW = sizeRatios.sort((a, b) => b.vw - a.vw)[0];
	const maxAspect = Math.max(...sizeRatios.map((r) => r.hwr));

	score.unitSize = outputWidth / maxVW.vw;

	// page display size
	score.pageSize = {
		width: outputWidth,
		height: outputWidth * maxAspect,
	};
}

/** Run a single async task immediately. */
const batchTask = <T>(fn: () => Promise<T>): Promise<T> => fn();

/** Start all async tasks at once and wait for every one of them. */
const concurrencyTask = <T>(fns: (() => Promise<T>)[]): Promise<T[]> => Promise.all(fns.map((fn) => fn()));

/**
 * Render the region of a system's background image around one staff into a
 * white-filled canvas of `spec.viewportHeight * scaling` pixels height.
 * @param system system that owns the staff
 * @param staffIndex index of the staff inside `system.staves`
 * @returns the canvas, or `null` when the system, its background image or the staff is missing
 */
const shootStaffImage = async (
	system: starry.System,
	staffIndex: number,
	{ paddingLeft = 0, scaling = 1, spec }: { paddingLeft?: number; scaling?: number; spec: { viewportHeight: number; viewportUnit: number } }
): Promise<Canvas | null> => {
	if (!system || !system.backgroundImage) return null;

	const staff = system.staves[staffIndex];
	if (!staff) return null;

	// Half of the viewport height, expressed in viewport units.
	const middleUnits = spec.viewportHeight / spec.viewportUnit / 2;

	const width = system.imagePosition.width * spec.viewportUnit;
	const height = system.imagePosition.height * spec.viewportUnit;
	const x = system.imagePosition.x * spec.viewportUnit + paddingLeft;
	const y = (system.imagePosition.y - (staff.top + staff.staffY - middleUnits)) * spec.viewportUnit;

	const canvas = new Canvas(Math.round(width + x) * scaling, spec.viewportHeight * scaling);
	const context = canvas.getContext('2d');
	context.fillStyle = 'white';
	context.fillRect(0, 0, canvas.width, canvas.height);
	context.drawImage(await loadImage(system.backgroundImage), x * scaling, y * scaling, width * scaling, height * scaling);

	return canvas;
};

/**
 * Cut system snapshots out of a page according to its layout detection result.
 * Sets `page.width` / `page.height` and fills `page.systems`.
 * @param page page whose `layout.areas` describe the systems
 * @param score score providing `pageSize` and `unitSize`
 * @param pageCanvas canvas with the original image already drawn on it
 * @returns the page canvas redrawn through `page.source.matrix`, or `null` when the page has no layout areas
 */
async function shootImageByDetection({
	page,
	score,
	pageCanvas,
}: {
	score: starry.Score;
	page: starry.Page;
	// canvas with the original image already drawn on it
	pageCanvas: Canvas;
}) {
	if (!page?.layout?.areas?.length) {
		return null;
	}

	page.width = score.pageSize.width / score.unitSize;
	page.height = score.pageSize.height / score.unitSize;

	const correctCanvas = new Canvas(pageCanvas.width, pageCanvas.height);
	const ctx = correctCanvas.getContext('2d');

	ctx.save();

	const { width, height } = correctCanvas;
	const [a, b, c, d] = page.source.matrix;

	ctx.setTransform(
		a,
		b,
		c,
		d,
		(-1 / 2) * width + (1 / 2) * a * width + (1 / 2) * b * height,
		(-1 / 2) * height + (1 / 2) * c * width + (1 / 2) * d * height
	);
	ctx.drawImage(pageCanvas, 0, 0);

	ctx.restore();

	const interval = page.source.interval;

	page.layout.areas.forEach((area, systemIndex) => {
		console.assert(area.staves?.middleRhos?.length, '[shootImageByDetection] empty area:', area);

		// Copy the pixels of the area into a canvas of its own.
		const data = ctx.getImageData(area.x, area.y, area.width, area.height);
		const canvas = new Canvas(area.width, area.height);
		const context = canvas.getContext('2d');
		context.putImageData(data, 0, 0);

		const detection = area.staves;
		const size = { width: area.width, height: area.height };

		const sourceCenter = {
			x: pageCanvas.width / 2 / interval,
			y: pageCanvas.height / 2 / interval,
		};

		const position = {
			x: (area.x + area.staves.phi1) / interval - sourceCenter.x + page.width / 2,
			y: area.y / interval - sourceCenter.y + page.height / 2,
		};

		page.systems[systemIndex] = constructSystem({
			page,
			backgroundImage: canvas.toBufferSync('png'),
			detection,
			imageSize: size,
			position,
		});
	});

	return correctCanvas;
}

/**
 * Shoot a staff snapshot with the semantic viewport and store it, without gauge correction,
 * as the staff's background image together with its image position.
 */
async function shootStaffBackgroundImage({ system, staff, staffIndex }: { system: starry.System; staff: starry.Staff; staffIndex: number }) {
	const sourceCanvas = await shootStaffImage(system, staffIndex, {
		paddingLeft: STAFF_PADDING_LEFT,
		spec: SEMANTIC_VISION_SPEC,
	});

	staff.backgroundImage = sourceCanvas.toBufferSync('png');
	staff.imagePosition = {
		x: -STAFF_PADDING_LEFT / SEMANTIC_VISION_SPEC.viewportUnit,
		y: staff.staffY - SEMANTIC_VISION_SPEC.viewportHeight / 2 / SEMANTIC_VISION_SPEC.viewportUnit,
		width: sourceCanvas.width / SEMANTIC_VISION_SPEC.viewportUnit,
		height: sourceCanvas.height / SEMANTIC_VISION_SPEC.viewportUnit,
	};
}

/**
 * Deformation correction (gauge) of a single staff.
 * Shoots the staff at 2x scaling, sends it with the gauge image to the `gaugeRenderer`
 * predictor and stores the returned buffer as the staff's background image.
 * Resets `staff.maskImage` to `null`.
 * @param system system that owns the staff
 * @param staff staff to update
 * @param staffIndex index of the staff inside the system
 * @param gaugeImage gauge image for this staff
 * @param pyClients predictor clients
 */
async function gaugeStaff({
	system,
	staff,
	staffIndex,
	gaugeImage,
	pyClients,
}: {
	system: starry.System;
	staff: starry.Staff;
	staffIndex: number;
	gaugeImage: Buffer;
	pyClients: PyClients;
}) {
	const sourceCanvas = await shootStaffImage(system, staffIndex, {
		paddingLeft: STAFF_PADDING_LEFT,
		spec: GAUGE_VISION_SPEC,
		scaling: 2,
	});

	const sourceBuffer = sourceCanvas.toBufferSync('png');

	const baseY = (system.middleY - (staff.top + staff.staffY)) * GAUGE_VISION_SPEC.viewportUnit + GAUGE_VISION_SPEC.viewportHeight / 2;

	const { buffer, size } = await pyClients.predictScoreImages('gaugeRenderer', [sourceBuffer, gaugeImage, baseY]);

	staff.backgroundImage = buffer;
	staff.imagePosition = {
		x: -STAFF_PADDING_LEFT / GAUGE_VISION_SPEC.viewportUnit,
		y: staff.staffY - size.height / 2 / GAUGE_VISION_SPEC.viewportUnit,
		width: size.width / GAUGE_VISION_SPEC.viewportUnit,
		height: size.height / GAUGE_VISION_SPEC.viewportUnit,
	};

	staff.maskImage = null;
}

/**
 * Denoising of a single staff: stores the mask image on the staff and derives
 * the staff's image position from the mask dimensions.
 * @param staff staff to update
 * @param staffIndex index of the staff (not used by the body)
 * @param maskImage mask image buffer
 */
async function maskStaff({ staff, staffIndex, maskImage }: { staff: starry.Staff; staffIndex: number; maskImage: Buffer }) {
	const img = await loadImage(maskImage);

	staff.maskImage = maskImage;
	staff.imagePosition = {
		x: -STAFF_PADDING_LEFT / MASK_VISION_SPEC.viewportUnit,
		y: staff.staffY - MASK_VISION_SPEC.viewportHeight / 2 / MASK_VISION_SPEC.viewportUnit,
		width: img.width / MASK_VISION_SPEC.viewportUnit,
		height: img.height / MASK_VISION_SPEC.viewportUnit,
	};
}

/**
 * Semantic recognition of a single staff: shifts the graph back by the snapshot's
 * left padding, assigns it to the system and the staff, then re-assembles the system.
 * @param score score that assembles the system
 * @param staffIndex index of the staff inside the system
 * @param system system that owns the staff
 * @param staff staff receiving the semantics
 * @param graph predicted semantic graph of the staff
 */
async function semanticStaff({
	score,
	staffIndex,
	system,
	staff,
	graph,
}: {
	score: starry.Score;
	staffIndex: number;
	system: starry.System;
	staff: starry.Staff;
	graph: SemanticGraph;
}) {
	graph.offset(-STAFF_PADDING_LEFT / SEMANTIC_VISION_SPEC.viewportUnit, 0);

	system.assignSemantics(staffIndex, graph);
	staff.assignSemantics(graph);
	staff.clearPredictedTokens();

	// NOTE(review): `||` turns a configured threshold of 0 into 1 — confirm whether 0 is a legal value.
	score.assembleSystem(system, score.settings?.semanticConfidenceThreshold || 1);
}

/**
 * Replace every image reference of a page in place: the page source url, each system's
 * background image and each staff's background and mask image.
 * @param page page to update
 * @param onReplaceImageKey maps the current value of an image field to its replacement
 */
function replacePageImages(page: starry.Page, onReplaceImageKey: (src: string) => any) {
	const tasks = [
		[page.source, 'url'],
		...page.systems
			.map((system) => {
				return [
					[system, 'backgroundImage'],
					...system.staves
						.map((staff) => [
							[staff, 'backgroundImage'],
							[staff, 'maskImage'],
						])
						.flat(),
				];
			})
			.flat(),
	];

	tasks.forEach(([target, key]: [any, string]) => {
		target[key] = onReplaceImageKey(target[key]);
	});
}

export type TaskProgress = { total?: number; finished?: number };

export interface OMRPage {
	url: string | Buffer;
	key?: string;
	layout?: LayoutResult;
	renew?: boolean;
	enableGauge?: boolean;
}

export interface ProgressState {
	layout?: TaskProgress;
	text?: TaskProgress;
	gauge?: TaskProgress;
	mask?: TaskProgress;
	semantic?: TaskProgress;
	regulate?: TaskProgress;
	brackets?: TaskProgress;
}

/** Tracks per-stage progress counters and reports every change to an optional listener. */
class OMRProgress {
	state: ProgressState = {};

	onChange?: (evt: ProgressState) => void;

	constructor(onChange?: (evt: ProgressState) => void) {
		this.onChange = onChange;
	}

	/** Register a stage with its total; a stage that is already registered is left untouched. */
	setTotal(stage: keyof ProgressState, total: number) {
		this.state[stage] = this.state[stage] || {
			total,
			finished: 0,
		};
	}

	/** Add `step` to the finished counter of a stage and notify the listener, if any. */
	increase(stage: keyof ProgressState, step = 1) {
		// Store the entry so that increments of a stage without a total are not lost.
		const info: TaskProgress = (this.state[stage] = this.state[stage] || { finished: 0 });
		info.finished = (info.finished ?? 0) + step;

		// The listener is optional (OMROption.onProgress may be omitted).
		this.onChange?.(this.state);
	}
}

type SourceImage = string | Buffer;

export interface OMROption {
	outputWidth?: number;
	// score title
	title?: string;
	pageStore?: {
		has?: (key: string) => Promise<boolean>;
		get: (key: string) => Promise<string>;
		set: (key: string, val: string) => Promise<void>;
	};
	renew?: boolean;
	// processes to run
	processes?: (keyof ProgressState)[];
	onProgress?: (progress: ProgressState) => void;
	// replaces every image reference, e.g. for uploading or format conversion
	onReplaceImage?: (src: SourceImage) => Promise<string>;
}

const lruCache = new WeakLRUCache();

// default page store, backed by the LRU cache above
const pageStore = {
	async get(key: string) {
		return lruCache.getValue(key) as string;
	},
	async set(key: string, val: string) {
		lruCache.setValue(key, val);
	},
};

/**
 * Default image replacement: buffers, http(s) urls and image data urls are converted
 * through `convertImage` and returned as a webp base64 data url; anything else is returned unchanged.
 * @param src image source
 */
const onReplaceImage = async (src: SourceImage) => {
	if (src instanceof Buffer || (typeof src === 'string' && (/^https?:\/\//.test(src) || /^data:image\//.test(src)))) {
		const webpBuffer = (await convertImage(src)).buffer;
		return `data:image/webp;base64,${webpBuffer.toString('base64')}`;
	}

	return src;
};

/**
 * Recognize all images and build a score from them.
 *
 * Missing fields of `option` are filled with defaults on the passed object itself.
 * @param pyClients predictor clients (also provides the logger)
 * @param images pages to recognize
 * @param option recognition options
 * @returns the score, the indices of pages that were omitted (taken from cache or without areas), and timing statistics
 */
export const predictPages = async (
	pyClients: PyClients,
	images: OMRPage[],
	option: OMROption = { outputWidth: 1200, pageStore, onReplaceImage }
): Promise<{ score: starry.Score; omitPages: number[]; stat: OMRStat }> => {
	const logger = pyClients.logger;

	option.outputWidth = option.outputWidth || 1200;
	option.pageStore = option.pageStore || pageStore;
	option.onReplaceImage = option.onReplaceImage || onReplaceImage;
	option.processes =
		Array.isArray(option.processes) && option.processes.length > 0 ? option.processes : ['layout', 'text', 'gauge', 'mask', 'semantic', 'brackets'];

	const progress: OMRProgress = new OMRProgress(option.onProgress);

	const t0 = Date.now();

	// Preprocessing: drop invalid areas from layouts supplied by the caller.
	images.forEach((image) => {
		if (image.layout?.detection) {
			image.layout.detection.areas = image.layout.detection?.areas?.filter((a) => a?.staves?.middleRhos?.length > 0);
		} else {
			delete image.layout;
		}
	});

	const score = new starry.Score({
		title: option?.title,
		stavesCount: 2,
		paperOptions: {
			raggedLast: true,
			raggedLastBottom: true,
		},
		headers: {},
		instrumentDict: {},
		settings: {
			enabledGauge: option.processes.includes('gauge'),
			semanticConfidenceThreshold: 1,
		},
	});

	logger.info(`[predictor]: download_source_images-${images.length}`);

	// original source images
	const originalImages: Image[] = await Promise.all(images.map((img) => loadImage(img.url as any)));

	logger.info(`[predictor]: source_images_downloaded-${images.length}`);

	/******************************* layout detection start *************************/
	// images fed to layout detection
	const pageCanvasList: Canvas[] = originalImages.map((img, index) => scaleForLayout(img, images[index]!.layout?.sourceSize?.width ?? img.width));

	progress.setTotal('layout', originalImages.length);
	progress.setTotal('text', originalImages.length);

	const detections = await Promise.all(
		pageCanvasList.map(async (cvs, key) => {
			if (!images[key].layout) return (await pyClients.predictScoreImages('layout', [cvs.toBufferSync('png')]))?.[0];

			// reinforce layout from front-end if no gauge
			if (!images[key].enableGauge && images[key]?.layout?.detection?.areas?.length)
				return (await pyClients.predictScoreImages('layout$reinforce', [cvs.toBufferSync('png')], [images[key].layout]))?.[0];

			return images[key].layout;
		})
	);

	detections.forEach((page) => {
		page.detection.areas = page.detection?.areas?.filter((a) => a?.staves?.middleRhos?.length > 0);
	});

	// Maps every collected image source to its replacement.
	const imageURLMap = new Map<SourceImage, string>();
	const collectImage = async (source: SourceImage): Promise<void> => {
		const url = await option.onReplaceImage(source);
		imageURLMap.set(source, url);
	};

	// Decide the global display size from the aspect ratios of all pages.
	setGlobalPageSize(score, detections, option.outputWidth);

	// Build the page at pageIndex (or recover it from the page store) and shoot its system images.
	async function createPage(detect, pageIndex) {
		const { url, key, layout, enableGauge } = images[pageIndex];

		const pageKey = sha1(JSON.stringify({ key: key || url, layout, enableGauge }));

		const cachedPageJson = await option.pageStore.get(pageKey);

		const omit = !option.renew && ((cachedPageJson && !images[pageIndex].renew) || !detect.detection.areas?.length);

		const page = (score.pages[pageIndex] =
			omit && cachedPageJson
				? starry.recoverJSON(cachedPageJson, starry)
				: new starry.Page({
						source: {
							name: key || (typeof url === 'string' && /https?:\/\//.test(url) ? url : null),
							size: 0,
							url,
							crop: {
								unit: '%',
								x: 0,
								y: 0,
								width: 100,
								height: 100,
							},
							dimensions: detect.sourceSize,
							matrix: [Math.cos(detect.theta), -Math.sin(detect.theta), Math.sin(detect.theta), Math.cos(detect.theta), 0, 0],
							interval: detect.interval,
							needGauge: images[pageIndex].enableGauge,
						},
						layout: detect.detection,
				  }));

		const correctCanvas = omit
			? null
			: await shootImageByDetection({
					score,
					page,
					pageCanvas: pageCanvasList[pageIndex],
			  });

		progress.increase('layout');

		return {
			page,
			omit,
			hash: pageKey,
			correctCanvas,
		};
	}

	const systemsCount = detections.reduce((acc, x) => acc + (x.detection.areas?.length ?? 0), 0);
	const stavesCount = detections.reduce((acc, x) => acc + (x.detection.areas?.reduce?.((a, y) => a + (y.staves?.middleRhos?.length ?? 0), 0) ?? 0), 0);

	progress.setTotal('gauge', stavesCount);
	progress.setTotal('mask', stavesCount);
	progress.setTotal('semantic', stavesCount);
	progress.setTotal('brackets', systemsCount);

	const allTasks: Promise<void>[] = [];
	const omitPages: number[] = [];

	const t1 = Date.now();

	let n_page = 0;

	for (const pageIndex of detections.keys()) {
		const pageTasks: Promise<void>[] = [];

		const { page, correctCanvas, omit, hash } = await createPage(detections[pageIndex], pageIndex);

		pageTasks.push(collectImage(page.source.url));
		pageTasks.push(...page.systems.map((system) => collectImage(system.backgroundImage)));

		logger.info(`[predictor]: check_cache_pageIndex-${pageIndex} omit: ${omit}`);

		if (omit) {
			omitPages.push(pageIndex);
		} else {
			const staves = page.systems
				.map((system, systemIndex) => system.staves.map((staff, staffIndex) => ({ pageIndex, systemIndex, staffIndex, page, system, staff })))
				.flat(1);

			await concurrencyTask([
				/******************************* bracket detection start *************************/
				async () => {
					if (!option.processes.includes('brackets')) return;

					const detection = page.layout;
					const interval = page.source.interval;

					const startTime = Date.now();

					const bracketImages = page.systems.map((system, systemIndex) => {
						const {
							x,
							y,
							staves: { middleRhos, phi1 },
						} = detection.areas[systemIndex];

						const topMid = middleRhos[0];
						const bottomMid = middleRhos[middleRhos.length - 1];

						// 8 intervals wide around the left edge (phi1), from 4 intervals above the
						// first staff middle to 4 intervals below the last one.
						const sourceRect = {
							x: x + phi1 - 4 * interval,
							y: y + topMid - 4 * interval,
							width: 8 * interval,
							height: bottomMid - topMid + 8 * interval,
						};

						const OUTPUT_INTERVAL = 8;

						const canvas = new Canvas(OUTPUT_INTERVAL * 8, (sourceRect.height / interval) * OUTPUT_INTERVAL);
						const context = canvas.getContext('2d');
						context.drawImage(correctCanvas, sourceRect.x, sourceRect.y, sourceRect.width, sourceRect.height, 0, 0, canvas.width, canvas.height);

						return {
							system,
							buffer: canvas.toBufferSync('png'),
						};
					});

					logger.info(`[predictor]: brackets js [pageIndex-${pageIndex}] duration: ${Date.now() - startTime}`);

					const bracketsRes = await pyClients.predictScoreImages('brackets', { buffers: bracketImages.map((x) => x.buffer) });
					progress.increase('brackets', bracketImages.length);

					bracketImages.forEach(({ system }, index) => {
						if (bracketsRes[index]) {
							system.bracketsAppearance = bracketsRes[index];
						}
					});
				},
				/******************************* bracket detection end *************************/

				/******************************* text recognition start *************************/
				async () => {
					if (!option.processes.includes('text')) return;

					try {
						const startTime = Date.now();

						const bufferForText = correctCanvas.toBufferSync('png');

						const resultLoc = await pyClients.predictScoreImages('textLoc', [bufferForText]);

						const location = resultLoc[0].filter((box) => box.score > 0);

						if (location.length > 0) {
							const [resultOCR] = await pyClients.predictScoreImages('textOcr', {
								buffers: [bufferForText],
								location,
							});

							page.assignTexts(resultOCR.areas, resultOCR.imageSize);
							page.assemble();
						}

						logger.info(`[predictor]: text js [pageIndex-${pageIndex}] duration: ${Date.now() - startTime}`);

						progress.increase('text');

						// Without an explicit title, take the largest 'Title' text of the first page.
						if (!option.title) {
							const coverTexts: {
								confidence: number;
								fontSize: number;
								id: string;
								text: string;
								textType: 'Title' | 'Author';
								type: starry.TokenType;
								width_: number;
								x: number;
								y: number;
							}[] = score.pages[0].tokens as any;

							if (Array.isArray(coverTexts) && coverTexts.length > 0) {
								const [titleToken] = coverTexts
									.filter((x) => x.type === starry.TokenType.Text && x.textType === 'Title')
									.sort((a, b) => b.fontSize - a.fontSize);

								if (titleToken) {
									score.title = titleToken.text;
								}
							}
						}
					} catch (err) {
						// JSON.stringify(new Error(...)) yields '{}', so log the stack/message of Error instances.
						const detail = err instanceof Error ? err.stack || err.message : JSON.stringify(err);
						logger.error(`[predictor]: text js [pageIndex-${pageIndex}] ${detail}`);
					}
				},
				/******************************* text recognition end *************************/

				async () => {
					/******************************* deformation correction start *************************/
					await batchTask(async () => {
						const disableGauge = !option.processes.includes('gauge') || images[pageIndex].enableGauge === false;

						if (!disableGauge) {
							const gaugeRes = await pyClients.predictScoreImages(
								'gauge',
								await Promise.all(
									staves.map(async ({ staffIndex, system }) => {
										const startTime = Date.now();

										const sourceCanvas = await shootStaffImage(system, staffIndex, {
											paddingLeft: STAFF_PADDING_LEFT,
											spec: GAUGE_VISION_SPEC,
										});

										logger.info(`[predictor]: gauge js shoot [page-${pageIndex}, staff-${staffIndex}] duration: ${Date.now() - startTime}`);

										return sourceCanvas.toBufferSync('png');
									})
								)
							);

							for (const [index, { system, staff, pageIndex, staffIndex }] of staves.entries()) {
								const startTime = Date.now();

								logger.info(`[predictor]: gauge js [page-${pageIndex}, staff-${staffIndex}] start..`);

								await gaugeStaff({
									pyClients,
									system,
									staff,
									staffIndex,
									gaugeImage: gaugeRes[index].image,
								});

								logger.info(`[predictor]: gauge js [page-${pageIndex}, staff-${staffIndex}] duration: ${Date.now() - startTime}`);

								progress.increase('gauge');

								pageTasks.push(collectImage(staff.backgroundImage));
							}
						} else {
							// Gauge disabled: use the plain staff snapshot as background image.
							for (const { system, staff, staffIndex } of staves) {
								await shootStaffBackgroundImage({
									system,
									staff,
									staffIndex,
								});

								pageTasks.push(collectImage(staff.backgroundImage));
							}
						}
					});
					/******************************* deformation correction end *************************/

					await concurrencyTask([
						/******************************* denoise start *************************/
						async () => {
							if (!option.processes.includes('mask')) return;

							const maskRes = await pyClients.predictScoreImages(
								'mask',
								staves.map(({ staff }) => staff.backgroundImage as Buffer)
							);

							for (const [index, { staff, staffIndex }] of staves.entries()) {
								const startTime = Date.now();

								await maskStaff({
									staff,
									staffIndex,
									maskImage: maskRes[index].image,
								});

								logger.info(`[predictor]: mask js [page-${pageIndex}, ${index}, staff-${staffIndex}] duration: ${Date.now() - startTime}`);

								progress.increase('mask');

								pageTasks.push(collectImage(staff.maskImage));
							}
						},
						/******************************* denoise end *************************/

						/******************************* semantic recognition start *************************/
						async () => {
							if (!option.processes.includes('semantic')) return;

							const semanticRes = starry.recoverJSON(
								await pyClients.predictScoreImages(
									'semantic',
									staves.map(({ staff }) => staff.backgroundImage as Buffer)
								),
								starry
							);

							staves.forEach(({ system }) => system.clearTokens());

							for (const [index, { staffIndex, system, staff }] of staves.entries()) {
								const startTime = Date.now();

								await semanticStaff({
									score,
									system,
									staff,
									staffIndex,
									graph: semanticRes[index],
								});

								logger.info(
									`[predictor]: semantic js [page-${pageIndex}, system-${system.index}, staff-${staff.index}] duration: ${
										Date.now() - startTime
									}`
								);

								progress.increase('semantic');
							}
						},
						/******************************* semantic recognition end *************************/
					]);
				},
			]);

			++n_page;
		}

		allTasks.push(
			Promise.all(pageTasks).then(() => {
				// Images that were never collected (e.g. staff images of a page recovered from the
				// store, or an absent mask) keep their current value instead of becoming undefined.
				replacePageImages(page, (src) => (imageURLMap.has(src) ? imageURLMap.get(src) : src));

				logger.info(`[predictor]: pageStore set: [${pageIndex}]`);

				return option.pageStore.set(hash, JSON.stringify(page));
			})
		);
	}

	const t2 = Date.now();

	await Promise.all(allTasks);

	logger.info(`[predictor]: inferenceStaffLayout: ${score.title}, [${score.systems.length}]`);

	score.inferenceStaffLayout();

	logger.info(`[predictor]: done: ${score.title}`);

	// correct semantic ids
	score.assemble();

	const t3 = Date.now();

	return {
		score,
		omitPages,
		stat: {
			cost: t3 - t0,
			pagesCost: t2 - t1,
			pages: n_page,
		},
	};
};

/**
 * Aggregate the statistics of several `predictPages` runs.
 * @param stats statistics of the individual runs
 * @returns total cost, cost per page (`null` when no page was processed), total pages and number of runs
 */
export const abstractOMRStats = (stats: OMRStat[]): OMRSummary => {
	// NOTE(review): pagesCostTotal is accumulated but not used by the summary;
	// costPerPage is derived from costTotal — confirm this is intended.
	const { costTotal, pagesTotal } = stats.reduce(
		(sum, stat) => ({
			costTotal: sum.costTotal + stat.cost,
			pagesCostTotal: sum.pagesCostTotal + stat.pagesCost,
			pagesTotal: sum.pagesTotal + stat.pages,
		}),
		{ costTotal: 0, pagesCostTotal: 0, pagesTotal: 0 }
	);

	return {
		costTotal,
		costPerPage: pagesTotal ? costTotal / pagesTotal : null,
		pagesTotal,
		scoreN: stats.length,
	};
};