大致解说请关注B站“qt不会写qt”,这里只贴出代码。
———————————————————————————————————————————
首先,我们需要一段用于处理 MIDI 的代码(保存为 data_preprocessing.py,需要与之后的主执行模块 main.py 放在同一个文件夹下):
'''
This module handles the data preprocessing steps for generating tensor files.
'''
import os
import torch
import music21
def process_midi_file(midi_file):
    '''
    Parse a MIDI file and extract its notes as (pitch, duration) pairs.

    Args:
        midi_file: Path of the MIDI file, handed to music21.converter.parse.

    Returns:
        A list of (pitch, duration) tuples in the order produced by
        recurse(). pitch is the MIDI note number and duration is the note
        length in quarter notes, stored as a plain float. Only
        music21.note.Note elements are kept; every other element type is
        ignored.
    '''
    midi_data = music21.converter.parse(midi_file)
    # quarterLength may come back as a fractions.Fraction (e.g. tuplets).
    # torch.tensor() cannot infer a dtype from a Fraction, so convert to a
    # plain float here, before the data is saved and later tensorised.
    return [
        (element.pitch.midi, float(element.duration.quarterLength))
        for element in midi_data.recurse()
        if isinstance(element, music21.note.Note)
    ]
def generate_tensor_files(midi_folder="H:/PlanB/data/Temporary_midi_folder",
                          tensor_folder="H:/PlanB/data/Tensor_files"):
    '''
    Convert every MIDI file in a folder into a saved ".pt" file.

    Each MIDI file is passed to process_midi_file and the resulting list is
    written with torch.save under the same base name. Files whose ".pt"
    output already exists are skipped, so the function can be re-run
    cheaply.

    Args:
        midi_folder: Folder scanned (non-recursively) for MIDI files.
        tensor_folder: Folder receiving the ".pt" files; created if missing.
    '''
    os.makedirs(tensor_folder, exist_ok=True)
    for file_name in sorted(os.listdir(midi_folder)):
        # splitext only touches the real extension; str.replace would also
        # rewrite a ".mid" occurring in the middle of the file name.
        stem, extension = os.path.splitext(file_name)
        if extension.lower() not in (".mid", ".midi"):
            continue
        midi_file = os.path.join(midi_folder, file_name)
        tensor_file = os.path.join(tensor_folder, stem + ".pt")
        if os.path.exists(tensor_file):
            continue
        try:
            tensor_data = process_midi_file(midi_file)
            if not isinstance(tensor_data, list):
                tensor_data = [tensor_data]  # Ensure tensor_data is a list
            torch.save(tensor_data, tensor_file)
        except Exception as e:
            # Deliberately best-effort: one unparsable MIDI file must not
            # abort the conversion of the rest of the data set.
            print(f"Error processing MIDI file '{midi_file}': {str(e)}")
def combine_tensor_files(tensor_folder="H:/PlanB/data/Tensor_files",
                         combined_file="H:/PlanB/data/combined.pt"):
    '''
    Concatenate all ".pt" files of a folder into a single file.

    The content of every ".pt" file is loaded with torch.load and appended,
    in file-name order, to one flat list that is written to combined_file.
    No separator is stored between the contents of different files.

    Args:
        tensor_folder: Folder containing the per-song ".pt" files.
        combined_file: Path of the combined output file.
    '''
    # os.listdir returns entries in arbitrary order; sorting keeps the
    # combined sequence reproducible between runs and machines.
    tensor_files = sorted(
        name for name in os.listdir(tensor_folder) if name.endswith(".pt")
    )
    combined_data = []
    for name in tensor_files:
        # NOTE: torch.load unpickles the file; only load ".pt" files that
        # were produced locally, never files from untrusted sources.
        combined_data.extend(torch.load(os.path.join(tensor_folder, name)))
    torch.save(combined_data, combined_file)
    print("Combined tensor files into a single file.")
请注意:运行此处代码时出现一些警告或报错(例如提示某些 MIDI 文件无法解析)都是非常正常的。如果大家的 MIDI 是从网上下载的、规模在四位数以上的数据集,其中肯定会有某些 MIDI 因为格式问题无法正常载入。目前为止我没有很好的解决办法,为了保证代码顺利运行,我只能让代码在抛出错误后跳过该文件继续运行。
接下来,我们需要一段用于训练模型的代码(截至目前,我的 ChatDev 仍然在为完善模型结构努力,目前这个模型结构肯定还需要修改,所以放出的只是目前的测试版)。它应命名为 model_training.py,同样要和主执行模块放在同一个文件夹里:
'''
This module handles the training of the LSTM model.
'''
import os
import torch
import torch.nn as nn
import torch.optim as optim
class LSTMModel(nn.Module):
    '''
    LSTM followed by a linear layer applied to the final hidden state.

    The LSTM is built with batch_first=True. forward() casts its input to
    float, takes the last layer's final hidden state, projects it through
    the linear layer and returns the result with size-1 dimensions removed.
    '''

    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        # Recurrent core, then the projection down to the output features.
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # The LSTM returns (output, (h_n, c_n)); only h_n is used here.
        _, (hidden_state, _) = self.lstm(x.float())
        # Index -1 selects the hidden state of the top LSTM layer.
        top_layer_hidden = hidden_state[-1]
        prediction = self.fc(top_layer_hidden)
        return prediction.squeeze()
def process_midi_file(midi_file):
    '''
    Parse a MIDI file and extract its notes as (pitch, duration) pairs.

    Args:
        midi_file: Path of the MIDI file, handed to music21.converter.parse.

    Returns:
        A list of (pitch, duration) tuples in the order produced by
        recurse(). pitch is the MIDI note number and duration is the note
        length in quarter notes, stored as a plain float. Only
        music21.note.Note elements are kept; every other element type is
        ignored.
    '''
    # The training module's import list does not include music21, so it is
    # imported here to keep this function callable without a NameError.
    import music21

    midi_data = music21.converter.parse(midi_file)
    # quarterLength may come back as a fractions.Fraction (e.g. tuplets).
    # torch.tensor() cannot infer a dtype from a Fraction, so convert to a
    # plain float up front.
    return [
        (element.pitch.midi, float(element.duration.quarterLength))
        for element in midi_data.recurse()
        if isinstance(element, music21.note.Note)
    ]
def train_model(combined_file="H:/PlanB/data/combined.pt",
                model_file="H:/PlanB/data/model.pt",
                checkpoint_file="H:/PlanB/data/checkpoint.pt",
                num_epochs=100):
    '''
    Train the LSTM to predict the next (pitch, duration) pair from the
    current one, then save the weights and a checkpoint.

    Args:
        combined_file: Path of the combined note list, read with torch.load.
        model_file: Path where the model state_dict is written (every 20
            epochs and once more after the last epoch).
        checkpoint_file: Path where the final checkpoint dict is written
            (keys: 'epoch', 'model_state_dict', 'optimizer_state_dict',
            'loss').
        num_epochs: Number of passes over the data; must be at least 1.

    Raises:
        ValueError: If the combined data is empty, holds fewer than two
            notes (no training pair can be formed), or num_epochs < 1.
    '''
    if num_epochs < 1:
        raise ValueError("num_epochs must be at least 1.")
    # Load the combined tensor data.
    # NOTE: torch.load unpickles the file; only load locally produced data.
    combined_data = torch.load(combined_file)
    # Check if the combined data is empty
    if not combined_data:
        raise ValueError("Combined tensor data is empty.")
    if len(combined_data) < 2:
        raise ValueError(
            "Combined tensor data needs at least two notes to form a training pair."
        )
    # Convert the whole sequence once instead of re-building tensors on
    # every step; float() also copes with non-float numeric durations.
    notes = torch.tensor(
        [(float(pitch), float(duration)) for pitch, duration in combined_data],
        dtype=torch.float,
    )
    # Initialize the LSTM model
    input_size = 2  # pitch and duration
    hidden_size = 128
    output_size = 2  # pitch and duration
    num_layers = 1
    model = LSTMModel(input_size, hidden_size, num_layers, output_size)
    # Define the loss function and optimizer
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # Train the model one (current note -> next note) pair at a time.
    for epoch in range(num_epochs):
        for current_note, next_note in zip(notes[:-1], notes[1:]):
            # Shape (batch=1, seq_len=1, features): the model then returns
            # a tensor of shape (2,), matching next_note exactly so MSELoss
            # does not have to broadcast.
            inputs = current_note.view(1, 1, input_size)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, next_note)
            loss.backward()
            optimizer.step()
        print(f"Epoch: {epoch + 1}, Loss: {loss.item()}")
        if epoch % 20 == 0:
            torch.save(model.state_dict(), model_file)
    # Save the final model and create a checkpoint for retraining breakpoints
    torch.save(model.state_dict(), model_file)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Detached so the checkpoint does not hold on to the autograd graph.
        'loss': loss.detach(),
    }, checkpoint_file)
最后,我们定义一下推理(生成音乐)的部分,它被直接内置在主执行文件 main.py 里了,代码如下。
'''
This module handles the inference of new songs using the trained model.
'''
import os
import torch
import music21
import torch.nn as nn
import model_training
import data_preprocessing
class LSTMModel(nn.Module):
    '''
    LSTM followed by a linear layer applied to the last time step.

    The LSTM is built with batch_first=True. forward() takes the LSTM output
    at index -1 of dimension 1, squeezes it, projects it through the linear
    layer and returns the squeezed result. If dimension 1 of the LSTM
    output is empty, a zero vector of the output size is returned instead.
    '''

    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        # Recurrent core, then the projection down to the output features.
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        sequence_output, _ = self.lstm(x)
        # Guard clause: nothing to project when dimension 1 is empty.
        if not sequence_output.shape[1] > 0:
            return torch.zeros(self.fc.out_features).squeeze()
        last_step = sequence_output[:, -1].squeeze()
        return self.fc(last_step).squeeze()
def generate_new_songs(model_file="H:/PlanB/data/model.pt",
                       checkpoint_file="H:/PlanB/data/checkpoint.pt",
                       starting_midi_file="H:/PlanB/file.mid",
                       output_midi_file="H:/PlanB/generated_song.mid",
                       num_notes=100):
    '''
    Generate a new song with the trained model and save it as a MIDI file.

    The last note of the starting MIDI file seeds the model; each predicted
    (pitch, duration) pair is fed back in as the next input until num_notes
    notes have been produced.

    Args:
        model_file: Path of the saved model state_dict.
        checkpoint_file: Path of the training checkpoint. When the file
            exists its 'model_state_dict' replaces the weights loaded from
            model_file; when it is missing, model_file alone is used.
        starting_midi_file: MIDI file whose last note seeds the generation.
        output_midi_file: Destination path; it must not exist yet.
        num_notes: Number of notes to generate.

    Raises:
        FileNotFoundError: If starting_midi_file does not exist.
        FileExistsError: If output_midi_file already exists.
        ValueError: If the starting MIDI file yields no notes.
    '''
    # Load the trained model
    input_size = 2  # pitch and duration
    hidden_size = 128
    output_size = 2  # pitch and duration
    num_layers = 1
    model = LSTMModel(input_size, hidden_size, num_layers, output_size)
    # NOTE: torch.load unpickles the file; only load locally produced files.
    model.load_state_dict(torch.load(model_file))
    if os.path.exists(checkpoint_file):
        checkpoint = torch.load(checkpoint_file)
        model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    # Load the starting MIDI file for inference
    if not os.path.exists(starting_midi_file):
        raise FileNotFoundError(f"MIDI file '{starting_midi_file}' not found.")
    # Perform inference to generate new songs
    try:
        # Fail before doing any work rather than after generating the song.
        if os.path.exists(output_midi_file):
            raise FileExistsError(f"MIDI file '{output_midi_file}' already exists.")
        tensor_data = data_preprocessing.process_midi_file(starting_midi_file)
        if not isinstance(tensor_data, list):
            tensor_data = [tensor_data]  # Ensure tensor_data is a list
        if not tensor_data:
            raise ValueError(f"MIDI file '{starting_midi_file}' contains no notes.")
        print("Starting MIDI file:", starting_midi_file)
        # Only the most recent note is ever fed to the model, so keep just
        # that note instead of re-tensorising the whole growing history.
        seed_pitch, seed_duration = tensor_data[-1]
        current = torch.tensor(
            [float(seed_pitch), float(seed_duration)], dtype=torch.float
        )
        generated_data = []
        with torch.no_grad():  # inference only: no autograd bookkeeping
            for _ in range(num_notes):
                # Shape (batch=1, seq_len=1, features).
                outputs = model(current.view(1, 1, input_size))
                pitch_value, duration_value = outputs.tolist()
                generated_data.append((pitch_value, duration_value))
                current = outputs
        # Convert the generated tensor data back to MIDI format
        min_quarter_length = 0.0625  # shortest note written: a 64th note
        generated_midi = music21.stream.Stream()
        for pitch_value, duration_value in generated_data:
            note = music21.note.Note()
            # The model is an unconstrained regressor: round to the nearest
            # semitone and clamp into the valid MIDI pitch range 0..127.
            note.pitch.midi = min(127, max(0, int(round(pitch_value))))
            # Zero or negative predictions are not usable note lengths.
            note.duration.quarterLength = max(min_quarter_length, float(duration_value))
            generated_midi.append(note)
        # Save the generated song as a MIDI file
        generated_midi.write('midi', fp=output_midi_file)
        print("Generated song saved as", output_midi_file)
    except Exception as e:
        print(f"Error generating new songs: {str(e)}")
        raise
if __name__ == "__main__":
    # Pipeline order matters: per-song tensor files are created first, then
    # merged into the combined file that train_model() loads, and only then
    # can the trained model be used for generation.
    data_preprocessing.generate_tensor_files()
    data_preprocessing.combine_tensor_files()
    model_training.train_model()
    generate_new_songs()
如果运行过程中有任何问题请直接评论!