You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
469 lines
16 KiB
469 lines
16 KiB
"""Utilities for scheduled parameter updates and learning-rate scheduling.

Includes object-path navigation for dynamically accessing/mutating nested
config attributes, a WarmupCosineScheduler for LR with linear warm-up and
cosine decay, and helpers for managing parameter change schedules.
"""
|
|
|
|
import numpy
|
|
import torch
|
|
import math
|
|
import re
|
|
from torch.optim import Optimizer
|
|
from torch.optim.lr_scheduler import _LRScheduler
|
|
from omegaconf.dictconfig import DictConfig
|
|
|
|
|
|
def _navigate_object_path(obj, path, split_char="@"):
    """Resolve ``path`` against ``obj`` one segment at a time.

    ``path`` is cut on ``split_char`` and every piece is handed to
    ``_process_path_segment`` together with the object reached so far.
    A piece may therefore mix the accessor forms that helper understands:

    - attribute access: ``attr``
    - method calls: ``method('param')``
    - subscripts: ``['key'][0]``
    - combinations: ``method('param')['key'][0]``

    Returns the object reached after the last piece has been applied.
    """
    node = obj
    for piece in path.split(split_char):
        node = _process_path_segment(node, piece)
    return node
|
|
|
|
|
|
def _process_path_segment(obj, segment):
|
|
"""
|
|
Process a single path segment that may contain:
|
|
- Simple attribute: attr
|
|
- Function call: method('param')
|
|
- Bracket access: ['key'][0]
|
|
- Combined: method('param')['key']
|
|
"""
|
|
current_obj = obj
|
|
|
|
# Parse the segment to identify different access patterns
|
|
i = 0
|
|
while i < len(segment):
|
|
if segment[i] == "[":
|
|
# Handle bracket access
|
|
bracket_end = _find_matching_bracket(segment, i)
|
|
bracket_content = segment[i + 1 : bracket_end]
|
|
|
|
# Evaluate the bracket content
|
|
if bracket_content.startswith("'") and bracket_content.endswith("'"):
|
|
# String key
|
|
key = bracket_content[1:-1]
|
|
current_obj = current_obj[key]
|
|
elif bracket_content.startswith('"') and bracket_content.endswith('"'):
|
|
# String key with double quotes
|
|
key = bracket_content[1:-1]
|
|
current_obj = current_obj[key]
|
|
elif bracket_content.lstrip("-").isdigit():
|
|
# Numeric index
|
|
index = int(bracket_content)
|
|
current_obj = current_obj[index]
|
|
else:
|
|
# Try to evaluate as expression (for complex keys)
|
|
try:
|
|
key = eval(bracket_content)
|
|
current_obj = current_obj[key]
|
|
except:
|
|
# Fallback to string key
|
|
current_obj = current_obj[bracket_content]
|
|
|
|
i = bracket_end + 1
|
|
|
|
else:
|
|
# Handle attribute access or function call
|
|
attr_start = i
|
|
# Find the end of the identifier (attribute or method name)
|
|
while i < len(segment) and (segment[i].isalnum() or segment[i] == "_"):
|
|
i += 1
|
|
|
|
if attr_start < i:
|
|
attr_name = segment[attr_start:i]
|
|
|
|
# Check if this is followed by parentheses (function call)
|
|
if i < len(segment) and segment[i] == "(":
|
|
# This is a function call
|
|
paren_end = _find_matching_paren(segment, i)
|
|
args_str = segment[i + 1 : paren_end]
|
|
|
|
# Parse and evaluate arguments
|
|
args = _parse_function_args(args_str)
|
|
|
|
# Call the method
|
|
method = getattr(current_obj, attr_name)
|
|
current_obj = method(*args)
|
|
|
|
i = paren_end + 1
|
|
else:
|
|
# This is a simple attribute access
|
|
if attr_name.lstrip("-").isdigit():
|
|
# Numeric index for direct access
|
|
current_obj = current_obj[int(attr_name)]
|
|
else:
|
|
# Attribute access
|
|
current_obj = getattr(current_obj, attr_name)
|
|
else:
|
|
# Skip non-alphanumeric characters that aren't brackets or parentheses
|
|
i += 1
|
|
|
|
return current_obj
|
|
|
|
|
|
def _find_matching_bracket(s, start):
|
|
"""Find the matching closing bracket for an opening bracket at position start."""
|
|
count = 1
|
|
i = start + 1
|
|
while i < len(s) and count > 0:
|
|
if s[i] == "[":
|
|
count += 1
|
|
elif s[i] == "]":
|
|
count -= 1
|
|
i += 1
|
|
return i - 1
|
|
|
|
|
|
def _find_matching_paren(s, start):
|
|
"""Find the matching closing parenthesis for an opening parenthesis at position start."""
|
|
count = 1
|
|
i = start + 1
|
|
while i < len(s) and count > 0:
|
|
if s[i] == "(":
|
|
count += 1
|
|
elif s[i] == ")":
|
|
count -= 1
|
|
i += 1
|
|
return i - 1
|
|
|
|
|
|
def _parse_function_args(args_str):
|
|
"""Parse function arguments from a string."""
|
|
if not args_str.strip():
|
|
return []
|
|
|
|
args = []
|
|
current_arg = ""
|
|
paren_count = 0
|
|
bracket_count = 0
|
|
in_quotes = False
|
|
quote_char = None
|
|
|
|
for char in args_str:
|
|
if char in ['"', "'"] and not in_quotes:
|
|
in_quotes = True
|
|
quote_char = char
|
|
current_arg += char
|
|
elif char == quote_char and in_quotes:
|
|
in_quotes = False
|
|
quote_char = None
|
|
current_arg += char
|
|
elif not in_quotes:
|
|
if char == "(":
|
|
paren_count += 1
|
|
current_arg += char
|
|
elif char == ")":
|
|
paren_count -= 1
|
|
current_arg += char
|
|
elif char == "[":
|
|
bracket_count += 1
|
|
current_arg += char
|
|
elif char == "]":
|
|
bracket_count -= 1
|
|
current_arg += char
|
|
elif char == "," and paren_count == 0 and bracket_count == 0:
|
|
args.append(_evaluate_arg(current_arg.strip()))
|
|
current_arg = ""
|
|
else:
|
|
current_arg += char
|
|
else:
|
|
current_arg += char
|
|
|
|
if current_arg.strip():
|
|
args.append(_evaluate_arg(current_arg.strip()))
|
|
|
|
return args
|
|
|
|
|
|
def _evaluate_arg(arg_str):
|
|
"""Evaluate a function argument string to its proper type."""
|
|
arg_str = arg_str.strip()
|
|
|
|
# String literals
|
|
if (arg_str.startswith("'") and arg_str.endswith("'")) or (
|
|
arg_str.startswith('"') and arg_str.endswith('"')
|
|
):
|
|
return arg_str[1:-1]
|
|
|
|
# Numeric literals
|
|
if arg_str.lstrip("-").replace(".", "").isdigit():
|
|
if "." in arg_str:
|
|
return float(arg_str)
|
|
else:
|
|
return int(arg_str)
|
|
|
|
# Boolean literals
|
|
if arg_str.lower() == "true":
|
|
return True
|
|
elif arg_str.lower() == "false":
|
|
return False
|
|
elif arg_str.lower() == "none":
|
|
return None
|
|
|
|
# For complex expressions, try eval (be careful in production)
|
|
try:
|
|
return eval(arg_str)
|
|
except:
|
|
# Fallback to string
|
|
return arg_str
|
|
|
|
|
|
def _get_final_target(obj, target_attr):
    """Read the value that ``target_attr`` designates on ``obj``.

    Paths containing brackets or parentheses are resolved with
    ``_process_path_segment``. Otherwise an integer-looking ``target_attr``
    (optionally negative) is used as an index and any other text as an
    attribute name.
    """
    if _is_complex_path(target_attr):
        return _process_path_segment(obj, target_attr)

    if target_attr.lstrip("-").isdigit():
        # Plain integer: positional index.
        return obj[int(target_attr)]

    return getattr(obj, target_attr)
|
|
|
|
|
|
def _set_final_target(obj, target_attr, value):
    """Write ``value`` to the location that ``target_attr`` designates on ``obj``.

    Paths containing brackets or parentheses are delegated to
    ``_set_complex_path_value``. Otherwise an integer-looking ``target_attr``
    (optionally negative) is assigned by index and any other text by
    attribute name.
    """
    if _is_complex_path(target_attr):
        # The helper locates the parent container and assigns the last element.
        _set_complex_path_value(obj, target_attr, value)
        return

    is_index = target_attr.lstrip("-").isdigit()
    if is_index:
        obj[int(target_attr)] = value
    else:
        setattr(obj, target_attr, value)
|
|
|
|
|
|
def _is_complex_path(path):
|
|
"""Check if a path contains complex access patterns (brackets or parentheses)."""
|
|
return "[" in path or "(" in path
|
|
|
|
|
|
def _set_complex_path_value(obj, path, value):
|
|
"""Set a value using a complex path by navigating to the parent and setting the final element."""
|
|
# Parse the path to find the parent path and final accessor
|
|
parent_obj = obj
|
|
|
|
# Find the last bracket or the final attribute
|
|
last_bracket = path.rfind("[")
|
|
last_paren = path.rfind("(")
|
|
|
|
if last_bracket > last_paren:
|
|
# Last accessor is a bracket
|
|
bracket_end = _find_matching_bracket(path, last_bracket)
|
|
parent_path = path[:last_bracket]
|
|
bracket_content = path[last_bracket + 1 : bracket_end]
|
|
|
|
if parent_path:
|
|
parent_obj = _process_path_segment(obj, parent_path)
|
|
|
|
# Set the value using bracket access
|
|
if bracket_content.startswith("'") and bracket_content.endswith("'"):
|
|
key = bracket_content[1:-1]
|
|
parent_obj[key] = value
|
|
elif bracket_content.startswith('"') and bracket_content.endswith('"'):
|
|
key = bracket_content[1:-1]
|
|
parent_obj[key] = value
|
|
elif bracket_content.lstrip("-").isdigit():
|
|
index = int(bracket_content)
|
|
parent_obj[index] = value
|
|
else:
|
|
try:
|
|
key = eval(bracket_content)
|
|
parent_obj[key] = value
|
|
except:
|
|
parent_obj[bracket_content] = value
|
|
else:
|
|
# No brackets, treat as simple attribute
|
|
if path.lstrip("-").isdigit():
|
|
obj[int(path)] = value
|
|
else:
|
|
setattr(obj, path, value)
|
|
|
|
|
|
def _resolve_scheduled_value(cfg, step):
    """Return ``(value, segment_index)`` of one schedule entry at ``step``."""
    sch_type = cfg["type"]
    if sch_type not in ("linear", "segment"):
        # An unknown type used to surface later as an unbound-variable error.
        raise ValueError(f"Unsupported schedule type: {sch_type!r}")

    seg_steps = cfg["seg_steps"]
    seg_vals = cfg["seg_vals"]
    if len(seg_vals) == 0:
        raise ValueError("Schedule 'seg_vals' must not be empty")

    # Index of the last segment whose start step has been reached. Stopping
    # at 0 keeps a step earlier than the first boundary on the first segment
    # instead of walking the index negative until it raises IndexError.
    last = len(seg_vals) - 1
    i = last
    while i > 0 and step < seg_steps[i]:
        i -= 1

    if sch_type == "segment" or i == last:
        # Piecewise-constant schedule, or past the final boundary.
        return seg_vals[i], i

    # Linear interpolation between boundary i and i + 1, clamped to [0, 1].
    span = seg_steps[i + 1] - seg_steps[i]
    t = (step - seg_steps[i]) / span if span != 0 else 0.0
    t = max(0.0, min(1.0, t))
    return (1.0 - t) * seg_vals[i] + t * seg_vals[i + 1], i


def update_scheduled_params(obj, scheduler_dict, step, split_char="@"):
    """Apply every schedule in ``scheduler_dict`` to ``obj`` for the given step.

    Each key of ``scheduler_dict`` is a target path. With ``split_char`` in
    the key, everything before the last ``split_char`` is resolved with
    ``_navigate_object_path`` and the remainder names the attribute, index or
    complex accessor to write. Each value is a mapping with:

    - ``type``: ``"linear"`` (interpolate between neighbouring ``seg_vals``)
      or ``"segment"`` (hold the value of the current segment).
    - ``seg_steps`` / ``seg_vals``: segment start steps and their values.
    - ``val_type`` (optional, default ``"float"``): name of a callable used
      to convert the computed value.
    - ``overwrite_dict`` (optional): for dict-like values, replace the target
      instead of updating its entries one by one.
    - ``trigger_func`` (optional): path of a zero-argument callable invoked
      when ``step`` equals the start step of the current segment.

    Args:
        obj: Root object that target paths are resolved against.
        scheduler_dict: Mapping of target path to schedule configuration.
        step: Current step.
        split_char: Separator between path segments.

    Returns:
        dict mapping each target path to the value that was applied.

    Raises:
        ValueError: If a schedule has an unknown ``type`` or empty ``seg_vals``.
    """
    scheduled_params_dict = {}

    for target, cfg in scheduler_dict.items():
        val_type = cfg.get("val_type", "float")

        target_obj, target_attr = obj, target
        if split_char in target:
            target_obj_str, target_attr = target.rsplit(split_char, 1)
            target_obj = _navigate_object_path(obj, target_obj_str, split_char)

        val, seg_index = _resolve_scheduled_value(cfg, step)

        # NOTE(security): eval() turns the configured type name into a
        # callable and will execute arbitrary code; schedules must come from
        # trusted configuration only.
        val = eval(val_type)(val)

        if isinstance(val, (DictConfig, dict)):
            if cfg.get("overwrite_dict", False):
                _set_final_target(target_obj, target_attr, val)
            else:
                # Merge entry by entry into the existing container.
                existing = _get_final_target(target_obj, target_attr)
                for k, v in val.items():
                    if isinstance(existing, dict):
                        existing[k] = v
                    else:
                        setattr(existing, k, v)
        else:
            _set_final_target(target_obj, target_attr, val)

        scheduled_params_dict[target] = val

        if "trigger_func" in cfg and step == cfg["seg_steps"][seg_index]:
            target_func = cfg["trigger_func"]
            print(f"Triggering function: {target_func}")
            if split_char in target_func:
                owner_path, func_name = target_func.rsplit(split_char, 1)
                owner = _navigate_object_path(obj, owner_path, split_char)
            else:
                owner, func_name = obj, target_func
            getattr(owner, func_name)()

    return scheduled_params_dict
|
|
|
|
|
|
class WarmupCosineScheduler(_LRScheduler):
    """Learning-rate schedule with a linear warm-up followed by cosine decay.

    While ``last_epoch`` is below ``num_warmup_steps`` each learning rate is
    ``base_lr * last_epoch / max(1, num_warmup_steps)``. Afterwards it
    follows half a cosine wave from ``base_lr`` down to ``final_lr``, reached
    at ``num_training_steps`` and held from then on.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        num_warmup_steps: int,
        num_training_steps: int,
        final_lr: float = 0.0,
        last_epoch: int = -1,
    ):
        # Stored before the base initializer runs; presumably it may already
        # call get_lr(), which reads these attributes.
        self.final_lr = final_lr
        self.num_training_steps = num_training_steps
        self.num_warmup_steps = num_warmup_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return one learning rate per entry of ``self.base_lrs`` for the current step."""
        step = self.last_epoch
        warmup = self.num_warmup_steps

        if step >= warmup:
            # Cosine phase: progress runs from 0 to 1 and is capped at 1.
            progress = float(step - warmup) / float(max(1, self.num_training_steps - warmup))
            cosine_decay = 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))
            return [
                self.final_lr + (peak_lr - self.final_lr) * cosine_decay
                for peak_lr in self.base_lrs
            ]

        # Warm-up phase: ramp linearly from 0 towards the base rate.
        return [peak_lr * float(step) / float(max(1, warmup)) for peak_lr in self.base_lrs]
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: first the scheduled-parameter update through a
    # complex path, then a plot of the warm-up/cosine learning-rate curve.

    class MockEventManager:
        """Stand-in exposing nested term configs through ``get_term_cfg``."""

        def __init__(self):
            self.configs = {
                "push_robot": {"params": {"velocity_range": {"x": [1.0, 2.0], "y": [0.5, 1.5]}}}
            }

        def get_term_cfg(self, term_name):
            return self.configs[term_name]

    class MockEnv:
        def __init__(self):
            self.event_manager = MockEventManager()

    class MockSimulator:
        def __init__(self):
            self.env = MockEnv()

    mock_obj = MockSimulator()

    # Two '@'-separated navigation segments, then a final segment that mixes
    # a method call with chained subscripts.
    test_path = "env@event_manager@get_term_cfg('push_robot')['params']['velocity_range']['x'][0]"

    # Linear schedule from 5.0 at step 0 to 10.0 at step 100.
    scheduler_config = {
        test_path: {"type": "linear", "seg_steps": [0, 100], "seg_vals": [5.0, 10.0]}
    }

    print("Testing complex path navigation...")
    print(
        f"Original value: {mock_obj.env.event_manager.get_term_cfg('push_robot')['params']['velocity_range']['x'][0]}"
    )

    # Step 50 lies halfway through the only interpolation interval.
    result = update_scheduled_params(mock_obj, scheduler_config, 50)
    print(
        f"Updated value: {mock_obj.env.event_manager.get_term_cfg('push_robot')['params']['velocity_range']['x'][0]}"
    )
    print(f"Scheduler result: {result}")

    # Step 150 lies past the last boundary, so the final value is held.
    result2 = update_scheduled_params(mock_obj, scheduler_config, 150)
    print(
        f"Updated value (step 150): {mock_obj.env.event_manager.get_term_cfg('push_robot')['params']['velocity_range']['x'][0]}"
    )
    print(f"Scheduler result: {result2}")

    print("\nOriginal learning rate scheduler test:")

    class YourModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(10, 1)

        def forward(self, x):
            return self.fc(x)

    model = YourModel()
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

    num_warmup_steps = 1000
    num_training_steps = 10000
    final_lr = 0.0001

    scheduler = WarmupCosineScheduler(optimizer, num_warmup_steps, num_training_steps, final_lr)

    lrs = []
    for step in range(num_training_steps):
        # PyTorch expects optimizer.step() before scheduler.step(); without
        # gradients this is a no-op that only avoids the ordering warning.
        optimizer.step()
        scheduler.step()
        lrs.append(scheduler.get_last_lr()[0])

    # Plotting the learning rate vs training steps
    import os

    import matplotlib.pyplot as plt

    plt.plot(range(num_training_steps), lrs)
    plt.xlabel("Training Steps")
    plt.ylabel("Learning Rate")
    plt.title("Learning Rate vs Training Steps")
    # plt.show()
    # savefig does not create missing directories, so make sure "out" exists.
    os.makedirs("out", exist_ok=True)
    plt.savefig("out/lr_vs_steps.png")
|