import os
import uuid
from dataclasses import dataclass, field
from urllib.parse import quote

from fastapi import APIRouter, UploadFile, File, HTTPException
from fastapi.responses import Response
from pydantic import BaseModel
import mido

from ..core.file_handling import load_midi_from_bytes, midi_to_bytes
from ..core.analyze import analyze_midi
from ..core.track_detail import get_track_detail
from ..core.midi_utils import has_musical_messages, get_instrument_name
from ..core import baketempo as baketempo_core
from ..core import monofy as monofy_core
from ..core import reduncheck as reduncheck_core
from ..core import velfix as velfix_core
from ..core import type0 as type0_core

router = APIRouter(prefix="/api/session")

# In-memory session store keyed by session id (a UUID4 string).
# NOTE(review): entries are never evicted, so a long-running server grows
# without bound; consider a TTL/LRU policy.
sessions: dict[str, "Session"] = {}

# Tools accepted by POST /{session_id}/apply.
_KNOWN_TOOLS = ("baketempo", "monofy", "reduncheck", "velfix", "type0")


@dataclass
class Session:
    """Editing state for one uploaded MIDI file."""

    # Current serialized MIDI file.
    midi_bytes: bytes
    # Filename supplied at upload time; used to name downloads.
    original_name: str
    # Previous midi_bytes snapshots, most recent last; `undo` pops one
    # snapshot together with one history label.
    undo_stack: list[bytes] = field(default_factory=list)
    # Human-readable labels of applied operations, parallel to undo_stack.
    history: list[str] = field(default_factory=list)


class ApplyRequest(BaseModel):
    """Body of POST /{session_id}/apply."""

    tool: str
    # Only shown in the history label; no tool here consumes it.
    channels: list[int] | None = None
    vel_min: int | None = None
    vel_max: int | None = None
    # 0-based musical track indices (see _musical_to_raw_indices).
    tracks: list[int] | None = None


class TrackEditRequest(BaseModel):
    """Body of POST /{session_id}/track/{track_index}/edit."""

    channel: int | None = None  # 1-based MIDI channel, validated to 1-16
    program: int | None = None  # MIDI program number, validated to 0-127


class MergeRequest(BaseModel):
    """Body of POST /{session_id}/merge."""

    tracks: list[int]  # 0-based musical track indices to merge


def _musical_to_raw_indices(midi, musical_indices: set[int]) -> set[int]:
    """Convert 0-based musical track indices to raw MIDI file track indices.

    Tracks for which ``has_musical_messages`` is false are skipped when
    numbering. Requested indices with no matching musical track are ignored.
    """
    raw = set()
    musical_idx = 0
    for i, track in enumerate(midi.tracks):
        if not has_musical_messages(track):
            continue
        if musical_idx in musical_indices:
            raw.add(i)
        musical_idx += 1
    return raw


def _find_musical_track(midi, track_index: int):
    """Find a musical track by its 0-based musical index.

    Returns (raw_index, track) or (None, None).
    """
    musical_idx = 0
    for i, track in enumerate(midi.tracks):
        if not has_musical_messages(track):
            continue
        if musical_idx == track_index:
            return i, track
        musical_idx += 1
    return None, None


def _analyze_session(session: Session) -> dict:
    """Parse the session's current MIDI and return its analysis dict."""
    midi = load_midi_from_bytes(session.midi_bytes)
    filename = os.path.splitext(session.original_name)[0]
    return analyze_midi(midi, filename)


def _commit(session: Session, new_bytes: bytes, label: str) -> None:
    """Make *new_bytes* the session's current state, recording undo + history.

    Callers produce *new_bytes* completely before calling this. A failed
    operation therefore never leaves an undo snapshot without its matching
    history label.
    """
    session.undo_stack.append(session.midi_bytes)
    session.midi_bytes = new_bytes
    session.history.append(label)


def _validate_velfix(request: ApplyRequest) -> None:
    """Raise HTTP 400 unless *request* carries a valid 0-127 velocity range."""
    if request.vel_min is None or request.vel_max is None:
        raise HTTPException(400, "vel_min and vel_max are required for velfix")
    if not (0 <= request.vel_min <= 127) or not (0 <= request.vel_max <= 127):
        raise HTTPException(400, "velocities must be 0-127")
    if request.vel_min > request.vel_max:
        raise HTTPException(400, "vel_min must be <= vel_max")


def _run_tool(midi, request: ApplyRequest):
    """Run the (already validated) tool named by ``request.tool`` on *midi*.

    Returns whatever the tool's ``process`` function returns. That value is
    serialized with ``midi_to_bytes``.
    """
    raw_tracks = None
    if request.tracks is not None:
        raw_tracks = _musical_to_raw_indices(midi, set(request.tracks))

    tool = request.tool
    if tool == "baketempo":
        return baketempo_core.process(midi)
    if tool == "monofy":
        return monofy_core.process(midi, raw_tracks)
    if tool == "reduncheck":
        return reduncheck_core.process(midi, raw_tracks)
    if tool == "velfix":
        return velfix_core.process(midi, request.vel_min, request.vel_max, raw_tracks)
    if tool == "type0":
        return type0_core.process(midi)
    # Unreachable when callers check _KNOWN_TOOLS first.
    raise ValueError(f"Unknown tool: {tool}")


def _content_disposition(filename: str) -> str:
    """Build an ``attachment`` Content-Disposition value that is always valid.

    HTTP header values must be latin-1 encodable and must not contain quotes
    or control characters. The ``filename`` parameter therefore gets a
    sanitized ASCII fallback. The exact name is sent percent-encoded in
    ``filename*`` (RFC 6266 / RFC 5987).
    """
    fallback = "".join(
        ch if 32 <= ord(ch) < 127 and ch not in '"\\' else "_" for ch in filename
    )
    encoded = quote(filename, safe="")
    return f"attachment; filename=\"{fallback}\"; filename*=UTF-8''{encoded}"


@router.post("/upload")
async def upload(file: UploadFile = File(...)):
    """Create a session from an uploaded .mid/.midi file and analyze it.

    Raises:
        HTTPException(400): missing/unsupported filename or unparsable MIDI.
    """
    # UploadFile.filename may be None, so guard before calling .lower().
    if not file.filename or not file.filename.lower().endswith(('.mid', '.midi')):
        raise HTTPException(400, "File must be a .mid or .midi file")

    content = await file.read()
    try:
        midi = load_midi_from_bytes(content)
    except Exception as e:
        raise HTTPException(400, f"Invalid MIDI file: {e}") from e

    session_id = str(uuid.uuid4())
    sessions[session_id] = Session(midi_bytes=content, original_name=file.filename)

    filename = os.path.splitext(file.filename)[0]
    analysis = analyze_midi(midi, filename)
    return {
        "session_id": session_id,
        "analysis": analysis,
        "history": [],
    }


@router.post("/{session_id}/apply")
async def apply_tool(session_id: str, request: ApplyRequest):
    """Apply a processing tool to the session's MIDI and record it for undo.

    Raises:
        HTTPException(404): unknown session.
        HTTPException(400): unknown tool or invalid velfix parameters.
        HTTPException(500): loading, processing or serialization failed.
    """
    session = sessions.get(session_id)
    if not session:
        raise HTTPException(404, "Session not found")

    # Validate up front so rejected requests never touch session state.
    if request.tool not in _KNOWN_TOOLS:
        raise HTTPException(400, f"Unknown tool: {request.tool}")
    if request.tool == "velfix":
        _validate_velfix(request)

    try:
        midi = load_midi_from_bytes(session.midi_bytes)
    except Exception as e:
        raise HTTPException(500, f"Failed to load MIDI: {e}") from e

    label = _tool_label(request)
    try:
        result = _run_tool(midi, request)
        # Serialize inside the try: a failure here must not leave a half-applied state.
        new_bytes = midi_to_bytes(result)
    except Exception as e:
        raise HTTPException(500, f"Processing error: {e}") from e

    _commit(session, new_bytes, label)
    analysis = _analyze_session(session)
    return {
        "analysis": analysis,
        "history": session.history,
    }


@router.post("/{session_id}/undo")
async def undo(session_id: str):
    """Restore the previous MIDI state and drop the latest history label.

    Raises:
        HTTPException(404): unknown session.
        HTTPException(400): nothing to undo.
    """
    session = sessions.get(session_id)
    if not session:
        raise HTTPException(404, "Session not found")
    if not session.undo_stack:
        raise HTTPException(400, "Nothing to undo")

    session.midi_bytes = session.undo_stack.pop()
    session.history.pop()
    analysis = _analyze_session(session)
    return {
        "analysis": analysis,
        "history": session.history,
    }


@router.get("/{session_id}/download")
async def download(session_id: str):
    """Return the current MIDI bytes as an attachment.

    The download is named ``<base>_edited.mid`` once any edit is in the
    history. Otherwise it keeps the original upload name.
    """
    session = sessions.get(session_id)
    if not session:
        raise HTTPException(404, "Session not found")

    base = os.path.splitext(session.original_name)[0]
    filename = f"{base}_edited.mid" if session.history else session.original_name
    return Response(
        content=session.midi_bytes,
        media_type="audio/midi",
        headers={"Content-Disposition": _content_disposition(filename)},
    )
def _load_session_midi(session: Session):
    """Parse the session's current MIDI bytes, mapping any failure to HTTP 500."""
    try:
        return load_midi_from_bytes(session.midi_bytes)
    except Exception as e:
        raise HTTPException(500, f"Failed to load MIDI: {e}") from e


def _track_name(track, default: str) -> str:
    """Return the text of the first track_name meta message in *track*, else *default*."""
    return next((msg.name for msg in track if msg.type == 'track_name'), default)


def _commit_track_change(session: Session, midi, label: str) -> dict:
    """Store edited *midi* as the session's new state and build the API response.

    *midi* is serialized before any session mutation. A serialization
    failure therefore leaves undo_stack and history untouched and still in sync.
    """
    new_bytes = midi_to_bytes(midi)
    session.undo_stack.append(session.midi_bytes)
    session.midi_bytes = new_bytes
    session.history.append(label)
    analysis = _analyze_session(session)
    return {
        "analysis": analysis,
        "history": session.history,
    }


def _set_track_channel(track, channel: int) -> str:
    """Move every channel-bearing message in *track* to 1-based *channel*.

    Returns the history-label fragment, e.g. ``"CH 1,2 → 3"``.
    """
    old_channels = sorted({msg.channel + 1 for msg in track if hasattr(msg, 'channel')})
    old_ch_str = ",".join(str(c) for c in old_channels) if old_channels else "?"
    for msg in track:
        if hasattr(msg, 'channel'):
            msg.channel = channel - 1
    return f"CH {old_ch_str} \u2192 {channel}"


def _set_track_program(track, program: int, default_channel: int) -> str:
    """Set the program of *track* and return the history-label fragment.

    If the track has a program_change, only the first one is updated.
    Otherwise a new program_change is inserted after the leading run of
    meta messages. It uses the channel of the track's first channel-bearing
    message, or *default_channel* (0-based) if there is none.
    """
    instrument_name = get_instrument_name(program)
    for msg in track:
        if msg.type == 'program_change':
            msg.program = program
            break
    else:
        ch = next((msg.channel for msg in track if hasattr(msg, 'channel')), default_channel)
        # First non-meta position; len(track) when the track is all meta (or empty).
        insert_idx = next((j for j, msg in enumerate(track) if not msg.is_meta), len(track))
        pc_msg = mido.Message('program_change', program=program, channel=ch, time=0)
        track.insert(insert_idx, pc_msg)
    return f"Instrument \u2192 {instrument_name}"


@router.get("/{session_id}/track/{track_index}")
async def track_detail(session_id: str, track_index: int):
    """Return ``get_track_detail`` for *track_index* of the session's MIDI.

    Raises:
        HTTPException(404): unknown session, or get_track_detail returned None.
        HTTPException(500): stored MIDI could not be loaded.
    """
    session = sessions.get(session_id)
    if not session:
        raise HTTPException(404, "Session not found")
    midi = _load_session_midi(session)
    detail = get_track_detail(midi, track_index)
    if detail is None:
        raise HTTPException(404, "Track not found")
    return detail


@router.post("/{session_id}/track/{track_index}/edit")
async def edit_track(session_id: str, track_index: int, request: TrackEditRequest):
    """Change the channel and/or program of one musical track (0-based index).

    Raises:
        HTTPException(404): unknown session or track.
        HTTPException(400): out-of-range values, or neither field supplied.
        HTTPException(500): stored MIDI could not be loaded.
    """
    session = sessions.get(session_id)
    if not session:
        raise HTTPException(404, "Session not found")
    # An empty edit would otherwise record a no-op undo/history entry.
    if request.channel is None and request.program is None:
        raise HTTPException(400, "Nothing to edit: provide channel and/or program")
    if request.channel is not None and not (1 <= request.channel <= 16):
        raise HTTPException(400, "Channel must be 1-16")
    if request.program is not None and not (0 <= request.program <= 127):
        raise HTTPException(400, "Program must be 0-127")

    midi = _load_session_midi(session)
    _, target_track = _find_musical_track(midi, track_index)
    if target_track is None:
        raise HTTPException(404, "Track not found")

    # Read the name before mutating so the label reflects the track as selected.
    track_name = _track_name(target_track, f"Track {track_index + 1}")

    label_parts = []
    if request.channel is not None:
        label_parts.append(_set_track_channel(target_track, request.channel))
    if request.program is not None:
        default_channel = request.channel - 1 if request.channel else 0
        label_parts.append(_set_track_program(target_track, request.program, default_channel))

    label = f"{track_name}: {', '.join(label_parts)}"
    return _commit_track_change(session, midi, label)


@router.post("/{session_id}/track/{track_index}/delete")
async def delete_track(session_id: str, track_index: int):
    """Remove one musical track (0-based index) from the session's MIDI.

    Raises:
        HTTPException(404): unknown session or track.
        HTTPException(500): stored MIDI could not be loaded.
    """
    session = sessions.get(session_id)
    if not session:
        raise HTTPException(404, "Session not found")

    midi = _load_session_midi(session)
    raw_idx, target_track = _find_musical_track(midi, track_index)
    if target_track is None:
        raise HTTPException(404, "Track not found")

    track_name = _track_name(target_track, f"Track {track_index + 1}")
    midi.tracks.pop(raw_idx)
    return _commit_track_change(session, midi, f"Delete {track_name}")


@router.post("/{session_id}/merge")
async def merge_tracks(session_id: str, request: MergeRequest):
    """Merge two or more musical tracks into one, time-aligned.

    The merged track replaces the sources at the position of the lowest
    source track. It is named after the first track_name found among the
    sources (in index order), or "Merged" if none has one.

    Raises:
        HTTPException(404): unknown session or track.
        HTTPException(400): fewer than 2 distinct tracks requested.
        HTTPException(500): stored MIDI could not be loaded.
    """
    session = sessions.get(session_id)
    if not session:
        raise HTTPException(404, "Session not found")

    # Deduplicate. Repeated indices would merge a track with itself and pop
    # the same raw index twice below, deleting an unrelated neighbour track.
    musical_indices = sorted(set(request.tracks))
    if len(musical_indices) < 2:
        raise HTTPException(400, "At least 2 tracks required for merge")

    midi = _load_session_midi(session)

    raw_indices = []
    for musical_idx in musical_indices:
        raw_idx, track = _find_musical_track(midi, musical_idx)
        if track is None:
            raise HTTPException(404, f"Track {musical_idx} not found")
        raw_indices.append(raw_idx)

    merged_name = next(
        (
            msg.name
            for raw_idx in raw_indices
            for msg in midi.tracks[raw_idx]
            if msg.type == 'track_name'
        ),
        "Merged",
    )

    # Flatten the sources into (absolute_tick, message) pairs.
    # NOTE(review): each source's end_of_track is copied too; mido's
    # MidiFile.save collapses them into one, so confirm midi_to_bytes uses it.
    all_events = []
    for raw_idx in raw_indices:
        absolute_time = 0
        for msg in midi.tracks[raw_idx]:
            absolute_time += msg.time
            if msg.type == 'track_name':
                continue  # the merged track gets a single name below
            all_events.append((absolute_time, msg.copy(time=0)))

    # Stable sort: same-tick events keep source-track order.
    all_events.sort(key=lambda event: event[0])

    merged_track = mido.MidiTrack()
    merged_track.append(mido.MetaMessage('track_name', name=merged_name, time=0))
    prev_time = 0
    for abs_time, msg in all_events:
        msg.time = abs_time - prev_time
        merged_track.append(msg)
        prev_time = abs_time

    # Pop highest index first so earlier indices stay valid, then insert.
    for raw_idx in sorted(raw_indices, reverse=True):
        midi.tracks.pop(raw_idx)
    midi.tracks.insert(min(raw_indices), merged_track)

    track_nums = ", ".join(str(t + 1) for t in musical_indices)
    return _commit_track_change(session, midi, f"Merge Tracks {track_nums}")


def _tool_label(request: ApplyRequest) -> str:
    """Build the history label for an apply request, e.g. ``"Monofy (Tracks 1,2)"``."""
    names = {
        "baketempo": "Bake Tempo",
        "monofy": "Monofy",
        "reduncheck": "Remove Redundancy",
        "velfix": "Velocity Fix",
        "type0": "Convert to Type 0",
    }
    label = names.get(request.tool, request.tool)
    parts = []
    if request.channels:
        parts.append(f"CH {','.join(str(c) for c in sorted(request.channels))}")
    if request.tracks is not None:
        # Stored indices are 0-based; show them 1-based.
        parts.append(f"Tracks {','.join(str(t + 1) for t in sorted(request.tracks))}")
    if request.vel_min is not None and request.vel_max is not None:
        parts.append(f"vel={request.vel_min}-{request.vel_max}")
    if parts:
        label += f" ({'; '.join(parts)})"
    return label