"""Split MIDI tracks into non-overlapping "voice" tracks.

Each selected track is parsed into discrete notes.  The notes are greedily
assigned to voices so that notes within one voice never overlap, and one new
track is emitted per voice.  Every non-note event of the original track
(except its ``track_name``) is copied into each voice track.
"""
from collections import defaultdict, deque
from dataclasses import dataclass, field

from mido import MidiFile, MidiTrack, Message, MetaMessage


@dataclass(eq=False)  # eq=False keeps identity equality and hashability
class Note:
    """A note reconstructed from a note-on / note-off pair.

    ``start`` and ``end`` are absolute times in ticks from the track start.
    """

    note: int
    start: int
    end: int
    channel: int
    velocity_on: int
    velocity_off: int
    # Voice index this note was placed in; filled in by ``assign_voices``.
    voice: int | None = field(default=None, init=False)


def get_notes_and_events(track):
    """Separate *track* into reconstructed notes and all other events.

    A ``note_on`` with velocity 0 is treated as a note-off.  If the same pitch
    is struck again on a channel before it is released, note-offs are matched
    first-in-first-out, so no note is lost.  Notes still sounding when the
    track ends are closed at the track's final tick with release velocity 0.
    A note-off with no matching note-on is kept as an ordinary event.

    Args:
        track: Iterable of mido messages whose ``time`` is a delta in ticks.

    Returns:
        ``(notes, non_note_events)``: a list of :class:`Note` and a list of
        ``(absolute_time, message)`` pairs, each message a copy with ``time=0``.
    """
    notes = []
    # (note, channel) -> FIFO of (start_time, velocity_on) still sounding.
    ongoing_notes = defaultdict(deque)
    non_note_events = []
    absolute_time = 0
    for msg in track:
        absolute_time += msg.time
        is_note_on = msg.type == 'note_on' and msg.velocity > 0
        is_note_off = msg.type == 'note_off' or (
            msg.type == 'note_on' and msg.velocity == 0
        )
        if is_note_on:
            ongoing_notes[(msg.note, msg.channel)].append(
                (absolute_time, msg.velocity)
            )
            continue
        if is_note_off:
            pending = ongoing_notes.get((msg.note, msg.channel))
            if pending:
                start_time, velocity_on = pending.popleft()
                notes.append(Note(msg.note, start_time, absolute_time,
                                  msg.channel, velocity_on, msg.velocity))
                continue
        # Anything else, including an unmatched note-off.
        non_note_events.append((absolute_time, msg.copy(time=0)))

    # Close notes that were never released so they are not silently dropped.
    for (pitch, channel), pending in ongoing_notes.items():
        for start_time, velocity_on in pending:
            notes.append(Note(pitch, start_time, absolute_time,
                              channel, velocity_on, 0))
    return notes, non_note_events


def assign_voices(notes):
    """Greedily assign each note to a voice so notes in a voice never overlap.

    Notes are visited in ``(start, pitch)`` order.  Each note goes to the
    lowest-numbered voice whose last note has ended by the time this one
    starts; a note may begin on the same tick the previous one ends.
    ``note.voice`` is set on every note.

    Args:
        notes: List of :class:`Note`; it is sorted in place.

    Returns:
        ``(notes, voice_count)``.
    """
    notes.sort(key=lambda n: (n.start, n.note))
    voice_ends = []  # voice_ends[v]: end tick of the last note placed in voice v
    for note in notes:
        voice = next(
            (v for v, end in enumerate(voice_ends) if end <= note.start), None
        )
        if voice is None:
            voice = len(voice_ends)
            voice_ends.append(note.end)
        else:
            voice_ends[voice] = note.end
        note.voice = voice
    return notes, len(voice_ends)


def create_track_name(original_name, voice):
    """Return the track name for *voice*: unchanged for voice 0, else ``name_voice``."""
    if voice == 0:
        return original_name
    return f"{original_name}_{voice}"


def merge_events(note_events, non_note_events, voice):
    """Build one voice's time-ordered message list.

    Args:
        note_events: Notes; only those with ``note.voice == voice`` are used.
        non_note_events: ``(absolute_time, message)`` pairs included as-is.
        voice: Voice index to render.

    Returns:
        Messages sorted by time with delta ``time`` values.  Every returned
        message is a fresh copy.  The same *non_note_events* can therefore be
        merged into several tracks without those tracks sharing message
        objects and overwriting each other's delta times.  On any given tick,
        ``end_of_track`` is placed after all other events.
    """
    events = list(non_note_events)
    for note in note_events:
        if note.voice != voice:
            continue
        events.append((note.start, Message('note_on', note=note.note, velocity=note.velocity_on, channel=note.channel, time=0)))
        events.append((note.end, Message('note_off', note=note.note, velocity=note.velocity_off, channel=note.channel, time=0)))
    # Stable sort: same-tick events keep insertion order, except that
    # end_of_track must not precede a note_off sharing its tick.
    events.sort(key=lambda e: (e[0], e[1].type == 'end_of_track'))

    merged_msgs = []
    prev_time = 0
    for abs_time, msg in events:
        # Copy rather than mutate: msg may be shared with other voice tracks.
        merged_msgs.append(msg.copy(time=abs_time - prev_time))
        prev_time = abs_time
    return merged_msgs


def process_track(track, original_track_index):
    """Split *track* into one new track per voice.

    Args:
        track: The mido track to split.
        original_track_index: Position of *track* in its file.  It is used for
            the fallback name ``Track{index}`` when the track has no
            ``track_name`` message.

    Returns:
        ``[track]`` (the same object) if the track contains no notes.
        Otherwise, one new track per voice.  Each new track starts with a
        ``track_name`` of ``<name>`` (voice 0) or ``<name>_<voice>``,
        followed by that voice's notes and every non-note event of *track*
        other than its ``track_name``.
    """
    track_name = next(
        (msg.name for msg in track if msg.type == 'track_name'),
        f"Track{original_track_index}",
    )
    notes, non_note_events = get_notes_and_events(track)
    if not notes:
        return [track]

    # Each voice track receives its own name below; keeping the original
    # track_name too would give every voice track two (conflicting) names.
    shared_events = [
        (abs_time, msg) for abs_time, msg in non_note_events
        if msg.type != 'track_name'
    ]

    assigned_notes, num_voices = assign_voices(notes)
    voice_to_notes = defaultdict(list)
    for note in assigned_notes:
        voice_to_notes[note.voice].append(note)

    new_tracks = []
    for voice in range(num_voices):
        new_track = MidiTrack()
        new_track.append(MetaMessage('track_name', name=create_track_name(track_name, voice), time=0))
        new_track.extend(merge_events(voice_to_notes[voice], shared_events, voice))
        new_tracks.append(new_track)
    return new_tracks


def process(mid: MidiFile, tracks: set[int] | None = None) -> MidiFile:
    """Return a new MIDI file with the selected tracks split into voices.

    Args:
        mid: Source file; it is not modified.
        tracks: Indices of tracks to split.  ``None`` splits every track.
            Tracks not listed are passed through unchanged.

    Returns:
        A new :class:`MidiFile` with the same ``ticks_per_beat``.
        Passed-through tracks, and tracks with no notes, are the same objects
        as in *mid*, not copies.
    """
    new_mid = MidiFile()
    new_mid.ticks_per_beat = mid.ticks_per_beat
    for i, track in enumerate(mid.tracks):
        if tracks is not None and i not in tracks:
            new_mid.tracks.append(track)
        else:
            new_mid.tracks.extend(process_track(track, i))
    return new_mid