You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
122 lines
3.8 KiB
122 lines
3.8 KiB
from mido import MidiFile, MidiTrack, Message, MetaMessage
|
|
from collections import defaultdict
|
|
|
|
|
|
class Note:
    """One sounding note rebuilt from a note-on / note-off pair.

    ``start`` and ``end`` hold the times supplied by the caller
    (``get_notes_and_events`` passes absolute ticks). ``voice`` starts as
    ``None`` and is filled in later by ``assign_voices``.
    """

    def __init__(self, note, start, end, channel, velocity_on, velocity_off):
        # Pitch and time span of the note.
        self.note, self.start, self.end = note, start, end
        self.channel = channel
        # Velocities of the opening note-on and the closing note-off.
        self.velocity_on, self.velocity_off = velocity_on, velocity_off
        # Output voice index; unassigned until voice allocation runs.
        self.voice = None
|
|
|
|
|
|
def get_notes_and_events(track):
    """Split *track* into paired notes and all remaining messages.

    Delta times are accumulated into absolute ticks while walking the track.
    A ``note_on`` with velocity > 0 opens a note keyed by ``(note, channel)``.
    The next ``note_off`` (or ``note_on`` with velocity 0) for that key closes
    it. A note-off that matches no open note is kept as an ordinary event.

    Edge cases that would otherwise lose notes:
      * a note re-triggered while still open is ended at the new note-on;
      * notes never switched off are ended at the track's final tick.

    Args:
        track: iterable of mido messages whose ``time`` is a delta in ticks.

    Returns:
        ``(notes, non_note_events)``. *notes* is a list of ``Note`` objects
        with absolute start/end ticks. *non_note_events* is a list of
        ``(absolute_tick, message)`` pairs, each message a copy with ``time=0``.
    """
    # Release velocity for notes this function must close on its own
    # (64 matches mido's default note_off velocity).
    implicit_release_velocity = 64

    notes = []
    ongoing_notes = {}  # (note, channel) -> (start_tick, velocity_on)
    non_note_events = []
    absolute_time = 0

    for msg in track:
        absolute_time += msg.time

        if msg.type == 'note_on' and msg.velocity > 0:
            key = (msg.note, msg.channel)
            if key in ongoing_notes:
                # Re-triggered before its note-off: end the earlier note here
                # rather than overwriting it, which would drop its note-on.
                start_time, velocity_on = ongoing_notes.pop(key)
                notes.append(Note(msg.note, start_time, absolute_time, msg.channel,
                                  velocity_on, implicit_release_velocity))
            ongoing_notes[key] = (absolute_time, msg.velocity)

        elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
            key = (msg.note, msg.channel)
            if key in ongoing_notes:
                start_time, velocity_on = ongoing_notes.pop(key)
                notes.append(Note(msg.note, start_time, absolute_time, msg.channel,
                                  velocity_on, msg.velocity))
            else:
                # Unmatched note-off: keep it so no message is lost.
                non_note_events.append((absolute_time, msg.copy(time=0)))

        else:
            non_note_events.append((absolute_time, msg.copy(time=0)))

    # Notes still sounding when the track ends are held until the last event
    # (normally end_of_track) instead of being silently dropped.
    for (pitch, channel), (start_time, velocity_on) in ongoing_notes.items():
        notes.append(Note(pitch, start_time, absolute_time, channel,
                          velocity_on, implicit_release_velocity))

    return notes, non_note_events
|
|
|
|
|
|
def assign_voices(notes):
    """Greedily distribute *notes* over the fewest non-overlapping voices.

    Notes are processed in ``(start, note)`` order. Each note goes to the
    lowest-numbered voice whose previous note has ended at or before this
    note's start; if none is free, a new voice is opened.

    Side effects: sorts *notes* in place and sets each note's ``voice``.

    Returns:
        ``(notes, voice_count)``: the same list object, now sorted, and the
        number of voices used.
    """
    notes.sort(key=lambda n: (n.start, n.note))

    # voice_ends[v] is the end time of the last note placed in voice v.
    voice_ends = []
    for current in notes:
        free = next(
            (idx for idx, busy_until in enumerate(voice_ends) if busy_until <= current.start),
            None,
        )
        if free is None:
            free = len(voice_ends)
            voice_ends.append(current.end)
        else:
            voice_ends[free] = current.end
        current.voice = free

    return notes, len(voice_ends)
|
|
|
|
|
|
def create_track_name(original_name, voice):
    """Name the track for *voice*.

    Voice 0 keeps *original_name* unchanged; every other voice gets the
    voice number appended as ``"<original_name>_<voice>"``.
    """
    return original_name if voice == 0 else f"{original_name}_{voice}"
|
|
|
|
|
|
def merge_events(note_events, non_note_events, voice):
    """Interleave one voice's notes with the shared non-note events.

    Args:
        note_events: ``Note`` objects; only those whose ``voice`` equals
            *voice* are emitted, as a note_on at ``start`` and a note_off
            at ``end``.
        non_note_events: ``(absolute_tick, message)`` pairs that are copied
            into every voice's track.
        voice: index of the voice being rendered.

    Returns:
        Messages in chronological order with ``time`` set to delta ticks.
        Every returned message is a fresh copy. *non_note_events* is never
        mutated, and no message object is shared between output tracks.
        (Previously the shared messages' ``time`` was overwritten in place,
        so each later voice corrupted the timing of tracks already built.)
    """
    events = list(non_note_events)

    for note in note_events:
        if note.voice != voice:
            continue
        events.append((note.start, Message('note_on', note=note.note, velocity=note.velocity_on, channel=note.channel, time=0)))
        events.append((note.end, Message('note_off', note=note.note, velocity=note.velocity_off, channel=note.channel, time=0)))

    # Stable sort: same-tick events keep insertion order (non-note events
    # first, then notes in the given order), except end_of_track, which must
    # stay after everything else on its tick.
    events.sort(key=lambda event: (event[0], event[1].type == 'end_of_track'))

    merged_msgs = []
    prev_time = 0
    for abs_time, msg in events:
        merged_msgs.append(msg.copy(time=abs_time - prev_time))
        prev_time = abs_time

    return merged_msgs
|
|
|
|
|
|
def process_track(track, original_track_index):
    """Split one MIDI track into one track per non-overlapping voice.

    The track's name comes from its first ``track_name`` meta message, or
    falls back to ``"Track<index>"``. Output tracks are named via
    ``create_track_name``. Every non-note event is copied into every output
    track, except the original ``track_name``, because each output track
    gets its own.

    Args:
        track: the source track (iterable of mido messages).
        original_track_index: position of *track* in its file, used for the
            fallback name.

    Returns:
        A list of new ``MidiTrack`` objects, one per voice, or ``[track]``
        unchanged when the track contains no notes.
    """
    track_name = f"Track{original_track_index}"
    for msg in track:
        if msg.type == 'track_name':
            track_name = msg.name
            break

    notes, non_note_events = get_notes_and_events(track)
    if not notes:
        return [track]

    # Each output track receives a fresh, voice-specific track_name below.
    # Drop the original so split tracks don't carry two conflicting names.
    shared_events = [(abs_time, msg) for abs_time, msg in non_note_events
                     if msg.type != 'track_name']

    assigned_notes, num_voices = assign_voices(notes)

    voice_to_notes = defaultdict(list)
    for note in assigned_notes:
        voice_to_notes[note.voice].append(note)

    new_tracks = []
    for voice in range(num_voices):
        new_track = MidiTrack()
        new_track.append(MetaMessage('track_name', name=create_track_name(track_name, voice), time=0))
        new_track.extend(merge_events(voice_to_notes[voice], shared_events, voice))
        new_tracks.append(new_track)

    return new_tracks
|
|
|
|
|
|
def process(mid: MidiFile, tracks: set[int] | None = None) -> MidiFile:
    """Build a new MidiFile in which selected tracks are split into voices.

    Args:
        mid: source file; its ``ticks_per_beat`` is carried over.
        tracks: indices of tracks to split. ``None`` means every track.
            Unselected tracks are appended unchanged (the same objects).

    Returns:
        A new ``MidiFile`` whose tracks keep the source order, with each
        processed track replaced by the tracks ``process_track`` returns.
    """
    result = MidiFile()
    result.ticks_per_beat = mid.ticks_per_beat

    for index, track in enumerate(mid.tracks):
        selected = tracks is None or index in tracks
        if selected:
            result.tracks.extend(process_track(track, index))
        else:
            result.tracks.append(track)

    return result
|