A personal exercise

The script below is an end-to-end optical music recognition (OMR) practice project: it binarizes a sheet-music image, detects and removes the staff lines, classifies the remaining glyphs with a CNN, and renders the result to MIDI/MusicXML/WAV behind a small Flask service.

import json
import shutil
import time
import warnings

warnings.filterwarnings("ignore")
from collections import Counter
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from music21 import converter
from midi2audio import FluidSynth
import os
from skimage.transform import hough_line, rotate, hough_line_peaks
from skimage.feature import corner_harris
from skimage.measure import label, regionprops
from skimage.color import label2rgb
from skimage.exposure import histogram as sk_histogram  # aliased: a local histogram() is defined below
from skimage.color import rgb2gray
from skimage.filters import threshold_otsu, gaussian, median
from skimage.morphology import binary_opening, binary_closing, binary_dilation, binary_erosion, opening, \
    square, disk
from skimage.feature import canny
from matplotlib.pyplot import bar
from scipy.ndimage import binary_fill_holes
from skimage.morphology import thin
import cv2
import math
import imutils
from flask import Flask, render_template
from flask import request, jsonify
from flask_cors import CORS, cross_origin
from gevent.greenlet import Greenlet
from gevent.pywsgi import WSGIServer
import skimage.io as io
from mido import Message, MidiFile, MidiTrack
import librosa


plt.rcParams['font.sans-serif'] = ['SimHei']  # render Chinese text correctly in matplotlib

headers = {'Content-Type': 'application/json'}
target_img_size = (100, 100)
sample_count = 50
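# label_map maps the staff-line index produced by estim() (0-7) and a 0/1 flag
# (on the line / in the space below it) to a pitch name such as 'c1'.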
label_map = {
    0: {
        0: 'N0'
    },
    1: {
        0: 'b2',
        1: 'a2'
    },
    2: {
        0: 'g2',
        1: 'f2'
    },
    3: {
        0: 'e2',
        1: 'd2'
    },
    4: {
        0: 'c2',
        1: 'b1'
    },
    5: {
        0: 'a1',
        1: 'g1'
    },
    6: {
        0: 'f1',
        1: 'e1'
    },
    7: {
        0: 'd1',
        1: 'c1'
    }
}
row_percentage = 0.3


class Segmenter(object):
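    """Split a binarized score image into per-staff regions.

    Column-wise run-length encoding yields the staff-line thickness and
    spacing; the staff lines are removed and the page is cut into one
    region per staff system (kept both with and without staff lines).
    """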
    def __init__(self, bin_img):
        self.bin_img = bin_img
        self.rle, self.vals = hv_rle(self.bin_img)
        self.most_common = get_most_common(self.rle)
        self.thickness, self.spacing = calculate_thickness_spacing(
            self.rle, self.most_common)
        self.thick_space = self.thickness + self.spacing
        self.no_staff_img = remove_staff_lines(
            self.rle, self.vals, self.thickness, self.bin_img.shape)

        self.segment()

    def open_region(self, region):
        thickness = np.copy(self.thickness)
        # if thickness % 2 == 0:
        #     thickness += 1
        return opening(region, np.ones((thickness, thickness)))

    def segment(self):
        self.line_indices = get_line_indices(histogram(self.bin_img, 0.8))
        if len(self.line_indices) < 10:  # fewer than two staves (5 lines each): keep the page whole
            self.regions_without_staff = [
                np.copy(self.open_region(self.no_staff_img))]
            self.regions_with_staff = [np.copy(self.bin_img)]
            return

        generated_lines_img = np.copy(self.no_staff_img)
        lines = []
        for index in self.line_indices:
            line = ((0, index), (self.bin_img.shape[1] - 1, index))
            lines.append(line)

        end_of_staff = []
        for index, line in enumerate(lines):
            if index > 0 and (line[0][1] - end_of_staff[-1][1] < 4 * self.spacing):
                pass
            else:
                p1, p2 = line
                x0, y0 = p1
                x1, y1 = p2
                end_of_staff.append((x0, y0, x1, y1))

        box_centers = []
        spacing_between_staff_blocks = []
        for i in range(len(end_of_staff) - 1):
            spacing_between_staff_blocks.append(
                end_of_staff[i + 1][1] - end_of_staff[i][1])
            if i % 2 == 0:
                offset = (end_of_staff[i + 1][1] - end_of_staff[i][1]) // 2
                center = end_of_staff[i][1] + offset
                box_centers.append((center, offset))

        max_staff_dist = np.max(spacing_between_staff_blocks)
        max_margin = max_staff_dist // 2
        margin = max_staff_dist // 10

        end_points = []
        regions_without_staff = []
        regions_with_staff = []
        for index, (center, offset) in enumerate(box_centers):
            y0 = int(center) - max_margin - offset + margin
            y1 = int(center) + max_margin + offset - margin
            end_points.append((y0, y1))

            region = self.bin_img[y0:y1, 0:self.bin_img.shape[1]]
            regions_with_staff.append(region)
            staff_block = self.no_staff_img[y0:y1,
                          0:self.no_staff_img.shape[1]]

            regions_without_staff.append(self.open_region(staff_block))

        self.regions_without_staff = regions_without_staff
        self.regions_with_staff = regions_with_staff


def extract_raw_pixels(img):
    resized = cv2.resize(img, target_img_size)
    return resized.flatten()


def extract_hsv_histogram(img):
    resized = cv2.resize(img, target_img_size)
    hsv = cv2.cvtColor(resized, cv2.COLOR_BGR2HSV)
    hist = cv2.calcHist([hsv], [0, 1, 2], None, [8, 8, 8],
                        [0, 180, 0, 256, 0, 256])
    if imutils.is_cv2():
        hist = cv2.normalize(hist)
    else:
        cv2.normalize(hist, hist)
    return hist.flatten()


def extract_hog_features(img):
    img = cv2.resize(img, target_img_size)
    win_size = (100, 100)
    cell_size = (4, 4)
    block_size_in_cells = (2, 2)

    block_size = (block_size_in_cells[1] * cell_size[1],
                  block_size_in_cells[0] * cell_size[0])
    block_stride = (cell_size[1], cell_size[0])
    nbins = 9  # Number of orientation bins
    hog = cv2.HOGDescriptor(win_size, block_size,
                            block_stride, cell_size, nbins)
    h = hog.compute(img)
    return h.flatten()


def extract_features(img, feature_set='raw'):
    if feature_set == 'hog':
        return extract_hog_features(img)
    elif feature_set == 'raw':
        return extract_raw_pixels(img)
    else:
        return extract_hsv_histogram(img)


def rle_encode(arr):
    """Run-length encode a 1-D array; returns (run_lengths, run_values)."""
    if len(arr) == 0:
        return [], []

    x = np.copy(arr)
    first_mismatch = np.array(x[1:] != x[:-1])
    mismatch_positions = np.append(np.where(first_mismatch), len(x) - 1)
    rle = np.diff(np.append(-1, mismatch_positions))
    values = [x[i] for i in np.cumsum(np.append(0, rle))[:-1]]
    return rle, values
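
# Example (sanity check): rle_encode(np.array([1, 1, 0, 0, 0, 1])) returns
# run lengths [2, 3, 1] with values [1, 0, 1].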


def hv_rle(img, axis=1):
    '''
    img: binary image
    axis: 0 for rows, 1 for cols
    '''
    rle, values = [], []

    if axis == 1:
        for i in range(img.shape[1]):
            col_rle, col_values = rle_encode(img[:, i])
            rle.append(col_rle)
            values.append(col_values)
    else:
        for i in range(img.shape[0]):
            row_rle, row_values = rle_encode(img[i])
            rle.append(row_rle)
            values.append(row_values)

    return rle, values


def rle_decode(starts, lengths, values):
    starts, lengths, values = map(np.asarray, (starts, lengths, values))
    ends = starts + lengths
    n = ends[-1]

    x = np.full(n, np.nan)
    for lo, hi, val in zip(starts, ends, values):
        x[lo:hi] = val
    return x


def hv_decode(rle, values, output_shape, axis=1):
    starts = [[int(np.sum(arr[:i])) for i in range(len(arr))] for arr in rle]

    decoded = np.zeros(output_shape, dtype=np.int32)
    if axis == 1:
        for i in range(decoded.shape[1]):
            decoded[:, i] = rle_decode(starts[i], rle[i], values[i])
    else:
        for i in range(decoded.shape[0]):
            decoded[i] = rle_decode(starts[i], rle[i], values[i])

    return decoded


def calculate_pair_sum(arr):
    if len(arr) == 1:
        return list(arr)
    else:
        res = [arr[i] + arr[i + 1] for i in range(0, len(arr) - 1, 2)]
        if len(arr) % 2 == 1:
            res.append(arr[-2] + arr[-1])
        return res


def get_most_common(rle):
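    """Return the most frequent (black run + following white run) length across all columns."""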
    pair_sum = [calculate_pair_sum(col) for col in rle]

    flattened = []
    for col in pair_sum:
        flattened += col

    most_common = np.argmax(np.bincount(flattened))
    return most_common


def most_common_bw_pattern(arr, most_common):
    if len(arr) == 1:
        # print("Empty")
        return []
    else:
        res = [(arr[i], arr[i + 1]) for i in range(0, len(arr) - 1, 2)
               if arr[i] + arr[i + 1] == most_common]

        if len(arr) % 2 == 1 and arr[-2] + arr[-1] == most_common:
            res.append((arr[-2], arr[-1]))
        # print(res)
        return res


class Box(object):
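    """Axis-aligned bounding box with overlap-ratio, center-distance and merge helpers."""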
    def __init__(self, x, y, w, h):
        self.x = x
        self.y = y
        self.w = w
        self.h = h
        self.center = (x + w / 2, y + h / 2)
        self.area = w * h

    def overlap(self, other):
        x = max(0, min(self.x + self.w, other.x + other.w) - max(other.x, self.x))
        y = max(0, min(self.y + self.h, other.y + other.h) - max(other.y, self.y))
        area = x * y
        return area / self.area

    def distance(self, other):
        return math.sqrt((self.center[0] - other.center[0]) ** 2 + (self.center[1] - other.center[1]) ** 2)

    def merge(self, other):
        x = min(self.x, other.x)
        y = min(self.y, other.y)  # bounding-box union: take the smaller origin on both axes
        w = max(self.x + self.w, other.x + other.w) - x
        h = max(self.y + self.h, other.y + other.h) - y
        return Box(x, y, w, h)

    def draw(self, img, color, thickness):
        top_left = (int(self.x), int(self.y))
        bottom_right = (int(self.x + self.w), int(self.y + self.h))
        cv2.rectangle(img, top_left, bottom_right, color, thickness)


def show_images(images, titles=None):
    n_ims = len(images)
    if titles is None:
        titles = ['(%d)' % i for i in range(1, n_ims + 1)]
    fig = plt.figure()
    n = 1
    for image, title in zip(images, titles):
        a = fig.add_subplot(1, n_ims, n)
        if image.ndim == 2:
            plt.gray()
        plt.imshow(image)
        a.set_title(title)
        plt.axis('off')
        n += 1
    fig.set_size_inches(np.array(fig.get_size_inches()) * n_ims)
    plt.show()


def showHist(img):
    plt.figure()
    imgHist = sk_histogram(img, nbins=256)  # skimage's histogram, not the local projection helper

    bar(imgHist[1].astype(np.uint8), imgHist[0], width=0.8, align='center')


def gray_img(img):
    '''
    img: rgb image
    return: gray image, pixel values 0:255
    '''
    img = img[:, :, :3]  # drop the alpha channel if present
    gray = rgb2gray(img) * 255
    return gray


def otsu(img):
    '''
    Otsu with gaussian
    img: gray image
    return: binary image, pixel values 0:1
    '''
    blur = gaussian(img)
    otsu_bin = 255 * (blur > threshold_otsu(blur))
    return (otsu_bin / 255).astype(np.int32)


def get_gray(img):
    gray = rgb2gray(np.copy(img))
    return gray


def get_thresholded(img, thresh):
    return 1 * (img > thresh)


def histogram(img, thresh):
    # Horizontal projection: count of black pixels per row, zeroing rows below
    # thresh * max (note: intentionally shadows skimage's histogram).
    hist = (np.ones(img.shape) - img).sum(dtype=np.int32, axis=1)
    _max = np.amax(hist)
    hist[hist[:] < _max * thresh] = 0
    return hist


def get_line_indices(hist):
    indices = []
    prev = 0
    for index, val in enumerate(hist):
        if val > 0 and prev <= 0:
            indices.append(index)
        prev = val
    return indices


def get_region_lines_indices(self, region):
    # Written method-style: expects an object exposing `thickness` and `rows`
    # (e.g. a Segmenter instance).
    indices = get_line_indices(histogram(region, 0.8))
    lines = []
    for line_index in indices:
        line = []
        for k in range(self.thickness):
            line.append(line_index + k)
        lines.append(line)
    self.rows.append([np.average(x) for x in lines])


def calculate_thickness_spacing(rle, most_common):
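    """Estimate staff-line thickness (the thin run) and spacing (the thick run)
    from the most common black/white run pair."""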
    bw_patterns = [most_common_bw_pattern(col, most_common) for col in rle]
    bw_patterns = [x for x in bw_patterns if x]  # Filter empty patterns

    flattened = []
    for col in bw_patterns:
        flattened += col

    pair, count = Counter(flattened).most_common()[0]

    line_thickness = min(pair)
    line_spacing = max(pair)

    return line_thickness, line_spacing


def whitene(rle, vals, max_height):
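    """Recolor black runs shorter than ~1.1x max_height (the staff-line
    thickness) to white and merge adjacent white runs, erasing staff lines
    while keeping taller glyphs."""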
    rlv = []
    for length, value in zip(rle, vals):
        if value == 0 and length < 1.1 * max_height:
            value = 1
        rlv.append((length, value))

    n_rle, n_vals = [], []
    count = 0
    for length, value in rlv:
        if value == 1:
            count = count + length
        else:
            if count > 0:
                n_rle.append(count)
                n_vals.append(1)

            count = 0
            n_rle.append(length)
            n_vals.append(0)
    if count > 0:
        n_rle.append(count)
        n_vals.append(1)

    return n_rle, n_vals


def remove_staff_lines(rle, vals, thickness, shape):
    n_rle, n_vals = [], []
    for i in range(len(rle)):
        rl, val = whitene(rle[i], vals[i], thickness)
        n_rle.append(rl)
        n_vals.append(val)

    return hv_decode(n_rle, n_vals, shape)


def remove_staff_lines_2(thickness, img_with_staff):
    img = img_with_staff.copy()
    rows, cols = img.shape
    for i in range(rows):
        white_count = np.sum(img[i] == 1)
        if white_count <= row_percentage * cols:  # mostly-black row: a staff line
            img[i, :] = 1
    opened = binary_opening(img, np.ones((3 * thickness, 1)))
    return opened


def get_rows(start, most_common, thickness, spacing):
    # start = start-most_common
    rows = []
    num = 6
    if start - most_common >= 0:
        start -= most_common
        num = 7
    for k in range(num):
        row = []
        for i in range(thickness):
            row.append(start)
            start += 1
        start += (spacing)
        rows.append(row)
    if len(rows) == 6:
        rows = [0] + rows
    return rows


def horizontal_projection(img):
    # Return the index of the first row that is at least 90% black, else 0.
    rows, cols = img.shape
    for i in range(rows):
        if np.sum(img[i] == 1) <= 0.1 * cols:
            return i
    return 0


def get_staff_row_position(img):
    # Index of the first row containing any black pixel, or -1 if none.
    for i in range(img.shape[0]):
        if np.any(img[i] == 0):
            return i
    return -1


def coordinator(bin_img, horizontal):
    rle, vals = hv_rle(bin_img)
    most_common = get_most_common(rle)
    thickness, spacing = calculate_thickness_spacing(rle, most_common)
    start = 0
    if horizontal:
        no_staff_img = remove_staff_lines_2(thickness, bin_img)
        staff_lines = otsu(bin_img - no_staff_img)
        start = horizontal_projection(bin_img)
    else:
        no_staff_img = remove_staff_lines(rle, vals, thickness, bin_img.shape)
        no_staff_img = binary_closing(
            no_staff_img, np.ones((thickness + 2, thickness + 2)))
        no_staff_img = median(no_staff_img)
        no_staff_img = binary_opening(
            no_staff_img, np.ones((thickness + 2, thickness + 2)))
        staff_lines = otsu(bin_img - no_staff_img)
        staff_lines = binary_erosion(
            staff_lines, np.ones((thickness + 2, thickness + 2)))
        staff_lines = median(staff_lines, selem=square(21))  # 'selem' was renamed 'footprint' in newer scikit-image
        start = get_staff_row_position(staff_lines)
    staff_row_positions = get_rows(
        start, most_common, thickness, spacing)
    staff_row_positions = [np.average(x) for x in staff_row_positions]
    return spacing, staff_row_positions, no_staff_img


def deskew(image):
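    """Estimate the page rotation angle in degrees from Canny edges, the
    Harris corner response and Hough line peaks."""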
    edges = canny(image, low_threshold=50, high_threshold=150, sigma=2)
    harris = corner_harris(edges)
    tested_angles = np.linspace(-np.pi / 2, np.pi / 2, 360)
    h, theta, d = hough_line(harris, theta=tested_angles)
    out, angles, d = hough_line_peaks(h, theta, d)
    rotation_number = np.average(np.degrees(angles))
    if rotation_number < 45 and rotation_number != 0:
        rotation_number += 90
    return rotation_number


def rotation(img, angle):
    image = rotate(img, angle, resize=True, mode='edge')
    return image


def get_closer(img):
    rows = []
    cols = []
    for x in range(16):
        no = 0
        for col in range(x * img.shape[0] // 16, (x + 1) * img.shape[0] // 16):
            for row in range(img.shape[1]):
                if img[col][row] == 0:
                    no += 1
        if no >= 0.01 * img.shape[1] * img.shape[0] // 16:
            rows.append(x * img.shape[0] // 16)
    for x in range(16):
        no = 0
        for row in range(x * img.shape[1] // 16, (x + 1) * img.shape[1] // 16):
            for col in range(img.shape[0]):
                if img[col][row] == 0:
                    no += 1
        if no >= 0.01 * img.shape[0] * img.shape[1] // 16:
            cols.append(x * img.shape[1] // 16)
    new_img = img[rows[0]:min(img.shape[0], rows[-1] + img.shape[0] // 16),
              cols[0]:min(img.shape[1], cols[-1] + img.shape[1] // 16)]
    return new_img


def IsHorizontal(img):
    # True if some row is at least 90% black, i.e. staff lines run horizontally.
    rows, cols = img.shape
    for i in range(rows):
        if np.sum(img[i] == 0) >= 0.9 * cols:
            return True
    return False


def get_connected_components(img_without_staff, img_with_staff):
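    """Label connected components (area >= 100) in the staff-free image and
    return the crops without and with staff lines plus their bounding boxes,
    sorted left to right."""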
    components = []
    boundary = []
    # thresh = threshold_otsu(img_without_staff)
    # bw = closing(img_without_staff <= thresh, square(3))
    bw = 1 - img_without_staff
    label_img = label(bw)
    for region in regionprops(label_img):
        if region.area >= 100:
            boundary.append(region.bbox)

    boundary = sorted(boundary, key=lambda b: b[1])

    comp_with_staff = []
    for bbox in boundary:
        minr, minc, maxr, maxc = bbox
        components.append(img_without_staff[minr:maxr, minc:maxc])
        comp_with_staff.append(img_with_staff[minr:maxr, minc:maxc])
    return components, comp_with_staff, boundary


def estim(c, idx, imgs_spacing, imgs_rows):
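    """Map vertical coordinate c to (line index, 0 = on the line / 1 = in the
    space below it), falling back to (7, 1)."""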
    spacing = imgs_spacing[idx]
    rows = imgs_rows[idx]
    margin = 1 + (spacing / 4)
    for index, line in enumerate(rows):
        if c >= line - margin and c <= line + margin:
            return index + 1, 0
        elif c >= line + margin and c <= line + 3 * margin:
            return index + 1, 1
    return 7, 1


def get_note_name(prev, octave, duration):
    if duration in ['4', 'a_4']:
        return f'{octave[0]}{prev}{octave[1]}/4'
    elif duration in ['8', '8_b_n', '8_b_r', 'a_8']:
        return f'{octave[0]}{prev}{octave[1]}/8'
    elif duration in ['16', '16_b_n', '16_b_r', 'a_16']:
        return f'{octave[0]}{prev}{octave[1]}/16'
    elif duration in ['32', '32_b_n', '32_b_r', 'a_32']:
        return f'{octave[0]}{prev}{octave[1]}/32'
    elif duration in ['2', 'a_2']:
        return f'{octave[0]}{prev}{octave[1]}/2'
    elif duration in ['1', 'a_1']:
        return f'{octave[0]}{prev}{octave[1]}/1'
    else:
        return "c1/4"


def filter_beams(prims, prim_with_staff, bounds):
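    """Drop components at least twice as wide as tall (e.g. beams), keeping note heads."""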
    n_bounds = []
    n_prims = []
    n_prim_with_staff = []
    for i, prim in enumerate(prims):
        if prim.shape[1] >= 2 * prim.shape[0]:
            continue
        else:
            n_bounds.append(bounds[i])
            n_prims.append(prims[i])
            n_prim_with_staff.append(prim_with_staff[i])
    return n_prims, n_prim_with_staff, n_bounds


def get_chord_notation(chord_list):
    chord_res = "{"
    for chord_note in chord_list:
        chord_res += (str(chord_note) + ",")
    chord_res = chord_res[:-1]
    chord_res += "}"

    return chord_res


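# Class names for the CNN outputs, in index order. Each entry carries a trailing
# numeric sample id (e.g. 'bar_121' is class 'bar'); recognize() strips that id.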
label_str = ['bar_121', '32_b_r_049', '8_b_n_017', 't44_200', '8_b_r_021', 'bar_b_128', 't44_b_196', 'flat_b_165',
             '8_013', 'a_16_091', 't24_b_193', 'a_4_071', 'flat_b_163', '32_036', 't4_186', 'bar_b_125', 'a_1_050',
             'dot_b_152', 'a_2_062', 'a_2_058', 'bar_119', 'a_2_060', '32_b_r_043', 'a_8_079', 'p_180', 'natural_169',
             'bar_099', 'bar_111', 'flat_b_164', '#', 't2_185', 'bar_106', 'natural_b_170', 'bar_115', 't44_197',
             'a_4_068', 'a_2_056', 'natural_167', 'dot_147', '#_b_007', 'natural_b_169', '2_006', '8_015', 't24_191',
             'a_8_077', 'a_16_092', 'a_8_086', 'natural_165', 'a_2_059', '1_001', 'a_4_076', 'a_1_049', 'a_4_066',
             't44_198', '16_022', 'a_2_066', 'natural_b_174', 't44_201', '#_b_002', 't24_192', 't24_190', 'flat_157',
             'clef_135', 'natural_168', 't24_b_195', '8_014', 'natural_b_177', 'bar_098', 'a_8_083', 'a_32_096',
             'a_8_082', '8_b_r_019', '32_b_n_042', 'clef_b_137', 'a_1_051', 'bar_123', 'bar_107', 'a_16_086', 't2_184',
             '#_b_009', 'bar_101', 't44_199', 'a_1_052', 'bar_109', '2_007', '32_037', '2_008', 't24_b_194', 't2_181',
             'bar_b_129', '16_024', '16_b_r_030', '32_b_r_046', 'chord_132', '16_b_r_036', 'flat_156', '32_038',
             '32_b_r_047', 'a_8_080', 'a_1_055', 'clef_137', 'a_8_081', '16_b_n_025', 'a_16_087', 'bar_125',
             '16_b_r_031', '32_b_n_041', 'a_2_063', 'a_4_072', 'a_4_074', 'bar_103', '16_b_r_033', 'p_179', 't44_b_195',
             'dot_150', '#_b_001', '16_b_n_027', '32_b_n_039', 'flat_b_161', 'clef_136', 't4_189', 't2_183', '#_2',
             'a_8_084', 'natural_166', 'bar_b_127', 't24_193', 'clef_b_138', 'flat_b_160', 'flat_b_162', 'dot_149',
             'bar_097', 'clef_b_140', 'a_8_085', 'a_4_069', 'dot_b_151', '16_b_r_035', 'a_2_064', 'clef_b_141', '4_011',
             'clef_133', 'a_32_095', 'clef_b_142', '1_004', 'bar_116', '16_025', 'chord_130', '8_b_r_018', 't24_189',
             'a_2_065', 'dot_b_150', '#_b_004', 'dot_148', '32_b_n_038', '#_3', 'a_4_073', '8_016', 'bar_122',
             '#_b_005', 'a_4_067', '16_b_r_032', 'chord_133', '#_1', 'bar_105', 't4_185', 'flat_153', '32_b_r_045',
             'a_4_070', 'natural_b_175', 'bar_118', 'dot_146', '#_b_006', 'flat_b_158', '16_b_n_028', 'bar_104',
             'natural_b_173', 'a_4_077', '1_002', '8_b_r_020', '#_b_008', '16_023', '32_b_r_048', 'a_4_075', '4_010',
             'bar_110', 't4_188', 'natural_b_172', 'natural_b_176', '16_b_n_026', 'bar_124', '32_b_n_040', 'p_177',
             't2_182', 'bar_b_126', 'p_178', 'a_1_053', 'flat_154', 'a_2_061', 'a_1_056', '2_005', '32_b_r_044',
             'bar_112', 't4_187', 'bar_120', 'bar_108', 'a_16_089', 'flat_b_157', 'clef_b_139', 'a_32_094', '16_021',
             'a_16_088', '1_003', 'a_8_078', 'bar_114', 'dot_b_153', 'clef_134', '8_b_n_018', 'a_16_094', 'flat_155',
             'p_181', 'bar_117', 'a_2_057', 'bar_102', 'bar_100', '16_b_n_030', '16_b_n_029', '#_b_003', 'a_16_093',
             'a_32_097', '32_b_n_043', '16_b_r_034', '4_012', 'a_1_054', 'chord_129', 'natural_b_171', 'bar_113',
             '4_009', 't44_b_197', 'chord_131', 'flat_b_159', 'a_16_090']

pic_model = tf.keras.models.load_model('resources/frame.h5')


# model = pickle.load(open('resources/nn_trained_model_hog.sav', 'rb'))
def predict(img):
    # img: grayscale uint8 glyph crop (white background, black symbol)
    features = cv2.resize(img, (100, 100)) / 255.0
    features = features[np.newaxis, :, :, np.newaxis]
    features = np.concatenate([features, features, features], -1)  # grayscale to 3 channels
    scores = pic_model(features).numpy()[0]
    return label_str[np.argmax(scores)]


def recognize(out_file, most_common, coord_imgs, imgs_with_staff, imgs_spacing, imgs_rows):
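    """Classify each connected component per staff region and write the result
    to out_file in a GUIDO-like bracket notation (notes as 'c1/4', chords in braces)."""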
    black_names = ['4', '8', '8_b_n', '8_b_r', '16', '16_b_n', '16_b_r',
                   '32', '32_b_n', '32_b_r', 'a_4', 'a_8', 'a_16', 'a_32', 'chord']
    ring_names = ['2', 'a_2']
    whole_names = ['1', 'a_1']
    disk_size = most_common / 4
    if len(coord_imgs) > 1:
        out_file.write("{\n")
    for i, img in enumerate(coord_imgs):
        res = []
        prev = ''
        time_name = ''
        primitives, prim_with_staff, boundary = get_connected_components(
            img, imgs_with_staff[i])
        for j, prim in enumerate(primitives):
            prim = binary_opening(prim, square(
                np.abs(most_common - imgs_spacing[i])))
            saved_img = (255 * (1 - prim)).astype(np.uint8)
            labels = predict(saved_img)
            octave = None
            # label_str entries end in a numeric sample id (e.g. 'bar_121');
            # strip it to recover the class name ('bar').
            head, _, tail = labels.rpartition('_')
            label = head if tail.isdigit() else labels
            if label in black_names:
                test_img = np.copy(prim_with_staff[j])
                test_img = binary_dilation(test_img, disk(disk_size))
                comps, comp_w_staff, bounds = get_connected_components(
                    test_img, prim_with_staff[j])
                comps, comp_w_staff, bounds = filter_beams(
                    comps, comp_w_staff, bounds)
                bounds = [np.array(bound) + disk_size - 2 for bound in bounds]

                if len(bounds) > 1 and label not in ['8_b_n', '8_b_r', '16_b_n', '16_b_r', '32_b_n', '32_b_r']:
                    l_res = []
                    bounds = sorted(bounds, key=lambda b: -b[2])
                    for k in range(len(bounds)):
                        idx, p = estim(
                            boundary[j][0] + bounds[k][2], i, imgs_spacing, imgs_rows)
                        l_res.append(f'{label_map[idx][p]}/4')
                        if k + 1 < len(bounds) and (bounds[k][2] - bounds[k + 1][2]) > 1.5 * imgs_spacing[i]:
                            idx, p = estim(
                                boundary[j][0] + bounds[k][2] - imgs_spacing[i] / 2, i, imgs_spacing, imgs_rows)
                            l_res.append(f'{label_map[idx][p]}/4')
                    res.append(sorted(l_res))
                else:
                    for bbox in bounds:
                        c = bbox[2] + boundary[j][0]
                        line_idx, p = estim(int(c), i, imgs_spacing, imgs_rows)
                        l = label_map[line_idx][p]
                        res.append(get_note_name(prev, l, label))
            elif label in ring_names:
                head_img = 1 - binary_fill_holes(1 - prim)
                head_img = binary_closing(head_img, disk(disk_size))
                comps, comp_w_staff, bounds = get_connected_components(
                    head_img, prim_with_staff[j])
                for bbox in bounds:
                    c = bbox[2] + boundary[j][0]
                    line_idx, p = estim(int(c), i, imgs_spacing, imgs_rows)
                    l = label_map[line_idx][p]
                    res.append(get_note_name(prev, l, label))
            elif label in whole_names:
                c = boundary[j][2]
                line_idx, p = estim(int(c), i, imgs_spacing, imgs_rows)
                l = label_map[line_idx][p]
                res.append(get_note_name(prev, l, label))
            elif label in ['bar', 'bar_b', 'clef', 'clef_b', 'natural', 'natural_b',
                           't24', 't24_b', 't44', 't44_b']:
                continue
            elif label in ['#', '#_b']:
                if prim.shape[0] == prim.shape[1]:
                    prev = '##'
                else:
                    prev = '#'
            elif label in ['cross']:
                prev = '##'
            elif label in ['flat', 'flat_b']:
                if prim.shape[1] >= 0.5 * prim.shape[0]:
                    prev = '&&'
                else:
                    prev = '&'
            elif label in ['dot', 'dot_b', 'p']:
                if len(res) == 0 or (
                        len(res) > 0 and res[-1] in ['flat', 'flat_b', 'cross', '#', '#_b', 't24', 't24_b', 't44',
                                                     't44_b']):
                    continue
                res[-1] += '.'
            elif label in ['t2', 't4']:
                time_name += label[1]
            elif label == 'chord':
                thinned = thin(1 - prim.copy(), max_iter=20)  # avoid clobbering the loop's `img`
                head_img = binary_closing(1 - thinned, disk(disk_size))
            if label not in ['flat', 'flat_b', 'cross', '#', '#_b']:
                prev = ''
        notes_str = ' '.join(
            [str(elem) if not isinstance(elem, list) else get_chord_notation(elem) for elem in res])
        if len(time_name) == 2:
            out_file.write(f'[ \\meter<"{time_name[0]}/{time_name[1]}">{notes_str}]\n')
        elif len(time_name) == 1:
            out_file.write(f'[ \\meter<"4/2">{notes_str}]\n')
        else:
            out_file.write(f'[ {notes_str}]\n')

    if len(coord_imgs) > 1:
        out_file.write("}")
    print("###########################", res, "##########################")


def start_server():
    app = make_handle_data()
    CORS(app, resources=r'/*')

    http_server = WSGIServer(('0.0.0.0', port), app)  # 'port' is set in the __main__ block
    http_server.start()

    print(f"serving start on port {port}")
    # return app
    return http_server


def model_result(img_path):
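    """Full pipeline for one image: deskew and binarize, split into staff
    regions, recognize the symbols, then render static/new.mid,
    static/new.musicxml and a randomly named WAV; returns the WAV filename."""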
    img = io.imread(img_path)
    img = gray_img(img)
    horizontal = IsHorizontal(img)
    if not horizontal:
        theta = deskew(img)
        img = rotation(img, theta)
        if img.ndim == 3:  # rgb2gray requires 3 channels in newer scikit-image
            img = get_gray(img)
        img = get_thresholded(img, threshold_otsu(img))
        img = get_closer(img)
        horizontal = IsHorizontal(img)

    original = img.copy()
    gray = get_gray(img) if img.ndim == 3 else img  # img may already be grayscale here
    bin_img = get_thresholded(gray, threshold_otsu(gray))

    segmenter = Segmenter(bin_img)
    imgs_with_staff = segmenter.regions_with_staff
    most_common = segmenter.most_common

    imgs_spacing = []
    imgs_rows = []
    coord_imgs = []
    for i, img in enumerate(imgs_with_staff):
        spacing, rows, no_staff_img = coordinator(img, horizontal)
        imgs_rows.append(rows)
        imgs_spacing.append(spacing)
        coord_imgs.append(no_staff_img)

    print("Recognize...")
    out_file = open(f'static/new.txt', "w")
    recognize(out_file, most_common, coord_imgs,
              imgs_with_staff, imgs_spacing, imgs_rows)
    out_file.close()

    with open(f'static/new.txt', 'r') as f:
        line = f.readlines()[0]
    line = line.replace('[', '').replace(']', '').replace('\n', '').replace('{', '').replace('}', '').replace('#',
                                                                                                              '').replace(
        '&', '')
    line = line.split()
    mid = MidiFile()  # the output file gets a .mid extension
    track = MidiTrack()  # one MidiTrack() is one voice/part
    mid.tracks.append(track)  # required: attach the track to the file, or it cannot be played later
    track.append(Message('program_change', channel=0, program=0, time=0))
    for token in line:
        name = token.split('/')[0]  # the duration after '/' is ignored here
        print(name)
        note = librosa.note_to_midi(name) + 21
        track.append(
            Message('note_on', note=note, velocity=64, time=1000, channel=0))
        track.append(Message('note_off', note=note, velocity=64, time=1000, channel=0))

    mid.save('static/new.mid')
    midi_data = converter.parse("static/new.mid")

    # Convert MIDI data to MusicXML format
    xml_data = midi_data[0].write('musicxml')
    shutil.move(xml_data, 'static/new.musicxml')
    wav_path = f'new_{np.random.randint(1, 10000)}.wav'
    sy = FluidSynth(sound_font='resources/GS.sf2', sample_rate=16000)
    sy.midi_to_audio('static/new.mid', f'static/{wav_path}')
    return wav_path


def make_handle_data():
    app = Flask(__name__, template_folder="static", static_folder="static", static_url_path="/")
    CORS(app)

    @app.route('/')
    def home():
        return render_template("login.html")

    @app.route('/recognition', methods=['POST', 'OPTIONS'])
    @cross_origin()
    def upload_audio():
        format_time = time.strftime("%m_%d_%H_%M_%S", time.localtime(time.time()))

        upload_path = 'resources'
        os.makedirs(upload_path, exist_ok=True)

        # save the uploaded image files
        img_files = []
        index = 1
        upload_files = request.files.to_dict()
        for file_name in upload_files:
            file = upload_files[file_name]
            img_path = f'{upload_path}/{format_time}_{index}.jpg'
            file.save(img_path)
            img_files.append(img_path)
            index = index + 1

        return jsonify({'result': model_result(img_files[0])})

    @app.route('/download_audio', methods=['POST', 'GET'])
    def download_audio():
        return app.send_static_file('new.musicxml')

    @app.route('/login', methods=['POST', 'OPTIONS'])
    @cross_origin()
    def login():
        # read the credentials from the JSON body
        username = request.json['username']
        password = request.json['password']

        with open('static/user.json', 'r') as f:
            user_info = json.load(f)

        is_match = 0
        for users in user_info['userinfo']:
            if users[0] == username and users[1] == password:
                is_match = 1
                break

        return jsonify({'result': is_match})

    @app.route('/register', methods=['POST', 'OPTIONS'])
    def register():
        # read the credentials from the JSON body
        username = request.json['username']
        password = request.json['password']

        with open('static/user.json', 'r') as f:
            user_info = json.load(f)

        users = user_info['userinfo']
        users.append([username, password])

        with open('static/user.json', 'w') as f:
            json.dump(user_info, f)

        return jsonify({'result': 'OK'})

    return app
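
# Example request (a sketch; assumes the server is running locally on port 8000
# and a 'score.jpg' exists -- the multipart field name is arbitrary):
#   curl -X POST -F "image=@score.jpg" http://localhost:8000/recognition
# The JSON response carries the generated WAV filename under static/.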


if __name__ == '__main__':
    port = 8000
    http_server = start_server()
    try:
        # Block until the server stops (uses a gevent-internal event).
        http_server._stop_event.wait()
    finally:
        Greenlet.spawn(http_server.stop, timeout=http_server.stop_timeout).join()
