def _sta_parse_selected_masks(
        mask_candidates: list[str],
        mask_selected: "list[int] | None") -> list[list[int]]:
    """Parse the selected comma-separated mask candidates into integer lists.

    ``mask_selected`` holds indices into ``mask_candidates``; ``None`` means
    "use every candidate, in order".
    """
    if mask_selected is None:
        mask_selected = list(range(len(mask_candidates)))
    return [[int(x) for x in mask_candidates[index].split(',')]
            for index in mask_selected]


def _sta_select_save_and_convert(results: Any,
                                 selected_masks: list[list[int]],
                                 save_dir: "str | None",
                                 time_step_num: int,
                                 layer_num: int,
                                 head_num: int,
                                 kwargs: dict[str, Any]) -> list[list[list[Any]]]:
    """Pick the best mask strategy from search results, save and report it.

    Shared tail of the 'STA_tuning' and 'STA_tuning_cfg' modes. The strategy
    is written to ``save_dir`` unless ``save_dir`` is ``None``.
    """
    # Losses are averaged over the parsed candidates only; the optional full
    # attention mask is appended afterwards, so it takes no part in averaging.
    averaged_results = average_head_losses(results, selected_masks)
    full_attention_mask: list[int] | None = kwargs.get('full_attention_mask')
    if full_attention_mask is not None:
        selected_masks.append(full_attention_mask)

    timesteps: int = kwargs.get('timesteps', time_step_num)
    skip_time_steps: int | None = kwargs.get('skip_time_steps')
    if skip_time_steps is None:
        skip_time_steps = 12

    mask_strategy, sparsity, strategy_counts = select_best_mask_strategy(
        averaged_results, selected_masks, skip_time_steps, timesteps,
        head_num)

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        file_path = os.path.join(save_dir,
                                 f'mask_strategy_s{skip_time_steps}.json')
        # Explicit encoding keeps the file independent of the platform locale.
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(mask_strategy, f, indent=4)
        print(f"Successfully saved mask_strategy to {file_path}")

    # Print sparsity and strategy counts for information
    print(f"Overall sparsity: {sparsity:.4f}")
    print("\nStrategy usage counts:")
    # Percentages are relative to the full time_step x layer x head grid.
    total_heads = time_step_num * layer_num * head_num
    for strategy, count in strategy_counts.items():
        print(
            f"Strategy {strategy}: {count} heads ({count/total_heads*100:.2f}%)"
        )

    return dict_to_3d_list(mask_strategy,
                           t_max=time_step_num,
                           l_max=layer_num,
                           h_max=head_num)


def configure_sta(mode: str = 'STA_searching',
                  layer_num: int = 40,
                  time_step_num: int = 50,
                  head_num: int = 40,
                  **kwargs) -> list[list[list[Any]]]:
    """
    Configure Sliding Tile Attention (STA) parameters based on the specified mode.

    Parameters:
    ----------
    mode : str
        The STA mode to use. Options are:
        - 'STA_searching': Generate a set of mask candidates for initial search
        - 'STA_tuning': Select best mask strategy based on previously saved results
        - 'STA_tuning_cfg': Like 'STA_tuning', but selects one strategy from the
          combined results of two search runs (positive and negative paths)
        - 'STA_inference': Load and use a previously tuned mask strategy
    layer_num: int, number of layers
    time_step_num: int, number of timesteps
    head_num: int, number of heads
    **kwargs : dict
        Mode-specific parameters:
        For 'STA_searching':
            - mask_candidates: list of str, required, comma-separated integer masks
            - mask_selected: list of int, optional, indices into mask_candidates
              (default, or None: all candidates)
        For 'STA_tuning':
            - mask_search_files_path: str, required, path to mask search results
            - mask_candidates: list of str, required, comma-separated integer masks
            - mask_selected: list of int, optional, indices into mask_candidates
              (default, or None: all candidates)
            - skip_time_steps: int, optional (default 12), passed to the strategy
              selection and used in the saved file name
            - timesteps: int, optional (default time_step_num)
            - full_attention_mask: list of int, optional, appended to the
              selected masks before strategy selection
            - save_dir: str, optional, directory to save mask strategy
              (default "mask_candidates"); pass None to skip saving
        For 'STA_tuning_cfg':
            - mask_search_files_path_pos: str, required
            - mask_search_files_path_neg: str, required
            - save_dir: str, required, directory to save mask strategy
            - mask_candidates, mask_selected, skip_time_steps, timesteps,
              full_attention_mask: as for 'STA_tuning'
        For 'STA_inference':
            - load_path: str, optional, path to load mask strategy
              (default "mask_candidates/mask_strategy.json")

    Returns:
    -------
    For 'STA_searching', a time_step_num x layer_num grid in which every cell
    is the list of parsed masks. All cells reference the same list object, so
    callers should treat it as read-only. For the other modes, the result of
    ``dict_to_3d_list`` applied to the mask strategy.

    Raises:
    ------
    ValueError
        If ``mode`` is unknown or a required mode-specific parameter is missing.

    Notes:
    -----
    The tuning modes save to ``mask_strategy_s{skip_time_steps}.json``, while
    the default ``load_path`` for 'STA_inference' is ``mask_strategy.json``;
    pass ``load_path`` explicitly to load a tuned file.
    """
    valid_modes = [
        'STA_searching', 'STA_tuning', 'STA_inference', 'STA_tuning_cfg'
    ]
    if mode not in valid_modes:
        raise ValueError(f"Mode must be one of {valid_modes}, got {mode}")

    if mode == 'STA_searching':
        mask_candidates: list[str] | None = kwargs.get('mask_candidates')
        if mask_candidates is None:
            raise ValueError(
                "mask_candidates is required for STA_searching mode")
        selected_masks = _sta_parse_selected_masks(
            mask_candidates, kwargs.get('mask_selected'))
        # Every (timestep, layer) position offers the full candidate set; the
        # same list object is shared by all cells.
        return [[selected_masks for _ in range(layer_num)]
                for _ in range(time_step_num)]

    if mode == 'STA_tuning':
        mask_search_files_path: str | None = kwargs.get(
            'mask_search_files_path')
        if not mask_search_files_path:
            raise ValueError(
                "mask_search_files_path is required for STA_tuning mode")
        mask_candidates_tuning: list[str] | None = kwargs.get(
            'mask_candidates')
        if mask_candidates_tuning is None:
            raise ValueError("mask_candidates is required for STA_tuning mode")
        # An explicit save_dir=None disables saving in this mode.
        save_dir_tuning: str | None = kwargs.get('save_dir', "mask_candidates")
        selected_masks_tuning = _sta_parse_selected_masks(
            mask_candidates_tuning, kwargs.get('mask_selected'))
        results = read_specific_json_files(mask_search_files_path)
        return _sta_select_save_and_convert(results, selected_masks_tuning,
                                            save_dir_tuning, time_step_num,
                                            layer_num, head_num, kwargs)

    if mode == 'STA_tuning_cfg':
        mask_search_files_path_pos: str | None = kwargs.get(
            'mask_search_files_path_pos')
        mask_search_files_path_neg: str | None = kwargs.get(
            'mask_search_files_path_neg')
        save_dir_cfg: str | None = kwargs.get('save_dir')
        if not mask_search_files_path_pos or not mask_search_files_path_neg or not save_dir_cfg:
            raise ValueError(
                "mask_search_files_path_pos, mask_search_files_path_neg, and save_dir are required for STA_tuning_cfg mode"
            )
        mask_candidates_cfg: list[str] | None = kwargs.get('mask_candidates')
        if mask_candidates_cfg is None:
            raise ValueError(
                "mask_candidates is required for STA_tuning_cfg mode")
        selected_masks_cfg = _sta_parse_selected_masks(
            mask_candidates_cfg, kwargs.get('mask_selected'))
        # Both result sets are concatenated so one strategy is selected from
        # the combined positive and negative search results.
        pos_results = read_specific_json_files(mask_search_files_path_pos)
        neg_results = read_specific_json_files(mask_search_files_path_neg)
        combined_results = pos_results + neg_results
        return _sta_select_save_and_convert(combined_results,
                                            selected_masks_cfg, save_dir_cfg,
                                            time_step_num, layer_num,
                                            head_num, kwargs)

    # STA_inference
    load_path: str | None = kwargs.get('load_path',
                                       "mask_candidates/mask_strategy.json")
    if load_path is None:
        raise ValueError("load_path is required for STA_inference mode")
    # Load previously saved mask strategy
    with open(load_path, encoding='utf-8') as f:
        mask_strategy = json.load(f)
    return dict_to_3d_list(mask_strategy,
                           t_max=time_step_num,
                           l_max=layer_num,
                           h_max=head_num)