#!/usr/bin/env python3
# /// script
# dependencies = []
# ///
"""
Estimate training time and cost for TRL jobs.

Usage:
    python estimate_cost.py --model <model> --dataset <dataset> --hardware <flavor>
    
Example:
    python estimate_cost.py --model Qwen/Qwen2.5-0.5B --dataset trl-lib/Capybara --hardware a10g-large
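
Using the built-in approximations, the example above reports roughly a 2.4 hour
estimate at about $12.00 on a10g-large; these figures are rough heuristics, not
measured training times.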
"""

import argparse
import math

# Hardware costs per hour (approximate)
HARDWARE_COSTS = {
    "t4-small": 0.75,
    "t4-medium": 1.50,
    "l4x1": 2.50,
    "a10g-small": 3.50,
    "a10g-large": 5.00,
    "a10g-largex2": 10.00,
    "a10g-largex4": 20.00,
    "a100-large": 10.00,
}

# Model sizes in billions of parameters
MODEL_SIZES = {
    "0.5B": 0.5,
    "1.5B": 1.5,
    "3B": 3,
    "7B": 7,
    "13B": 13,
}

def estimate_training_time(model_params, dataset_size, epochs, hardware):
    """Estimate training time in hours."""
    # Rough estimates based on empirical observations
    # These are approximations and actual times will vary
    
    base_time_per_1k_examples = 0.1  # hours per 1,000 examples per epoch for a 1B model on a10g-large
    
    # Adjust for model size
    time = base_time_per_1k_examples * model_params * (dataset_size / 1000) * epochs
    
    # Adjust for hardware (relative to a10g-large baseline)
    hardware_multipliers = {
        "t4-small": 2.0,
        "t4-medium": 1.5,
        "l4x1": 1.2,
        "a10g-small": 1.3,
        "a10g-large": 1.0,
        "a10g-largex2": 0.6,
        "a10g-largex4": 0.4,
        "a100-large": 0.7,
    }
    
    multiplier = hardware_multipliers.get(hardware, 1.0)
    time *= multiplier
    
    return time
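
# Worked example (illustrative, using only the formula above): a 0.5B model,
# 16,000 examples, and 3 epochs on the a10g-large baseline give
# 0.1 * 0.5 * (16000 / 1000) * 3 = 2.4 hours; on a10g-largex2 (0.6x multiplier)
# that drops to roughly 1.4 hours.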

def parse_args():
    parser = argparse.ArgumentParser(description="Estimate training cost for TRL jobs")
    parser.add_argument("--model", required=True, help="Model name or size (e.g., 'Qwen/Qwen2.5-0.5B' or '0.5B')")
    parser.add_argument("--dataset", required=True, help="Dataset name")
    parser.add_argument("--hardware", required=True, choices=HARDWARE_COSTS.keys(), help="Hardware flavor")
    parser.add_argument("--dataset-size", type=int, help="Override dataset size (number of examples)")
    parser.add_argument("--epochs", type=int, default=3, help="Number of training epochs")
    return parser.parse_args()

def extract_model_size(model_name):
    """Extract model size from name or return parsed value."""
    for size_str, size_val in MODEL_SIZES.items():
        if size_str in model_name:
            return size_val
    
    # Try to parse a bare size string such as "0.5B" directly
    try:
        if model_name.endswith("B"):
            return float(model_name[:-1])
    except ValueError:
        pass
    
    return 1.0  # Default to 1B if the size can't be determined
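
# For example (illustrative): extract_model_size("Qwen/Qwen2.5-0.5B") matches the
# "0.5B" entry in MODEL_SIZES and returns 0.5; a bare "2B" falls through to the
# direct parse and returns 2.0; an unrecognized name defaults to 1.0.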

def main():
    args = parse_args()
    
    # Extract model parameters
    model_params = extract_model_size(args.model)
    print(f"πŸ“Š Model: {args.model} (~{model_params}B parameters)")
    
    # Estimate dataset size (loading the dataset would be needed to get the real size)
    if args.dataset_size:
        dataset_size = args.dataset_size
    else:
        # Common dataset sizes (approximations)
        dataset_sizes = {
            "trl-lib/Capybara": 16000,
            "Anthropic/hh-rlhf": 160000,
        }
        dataset_size = dataset_sizes.get(args.dataset, 10000)
    
    print(f"πŸ“¦ Dataset: {args.dataset} (~{dataset_size} examples)")
    print(f"πŸ”„ Epochs: {args.epochs}")
    print(f"πŸ’» Hardware: {args.hardware}")
    print()
    
    # Estimate training time
    estimated_hours = estimate_training_time(model_params, dataset_size, args.epochs, args.hardware)
    estimated_cost = estimated_hours * HARDWARE_COSTS[args.hardware]
    
    # Recommend a timeout with a 30% buffer, rounded up to at least one whole hour
    recommended_timeout_hours = estimated_hours * 1.3
    timeout_hours = max(1, math.ceil(recommended_timeout_hours))
    
    print(f"⏱️  Estimated training time: {estimated_hours:.1f} hours")
    print(f"πŸ’° Estimated cost: ${estimated_cost:.2f}")
    print(f"⏰ Recommended timeout: {recommended_timeout_hours:.1f}h (with 30% buffer)")
    print()
    
    # Warnings and recommendations
    if estimated_hours > 4:
        print("⚠️  Long training time - consider:")
        print("   - Using faster hardware")
        print("   - Reducing epochs")
        print("   - Using a smaller dataset subset for testing")
    
    if model_params >= 7 and args.hardware not in ["a10g-largex2", "a10g-largex4", "a100-large"]:
        print("⚠️  Large model - consider using:")
        print("   - Larger GPU (a100-large)")
        print("   - Multi-GPU setup (a10g-largex2 or a10g-largex4)")
        print("   - LoRA/PEFT for memory efficiency")
    
    print()
    print("πŸ“‹ Example job configuration:")
    print(f"""
hf_jobs("uv", {{
    "script": "your_training_script.py",
    "flavor": "{args.hardware}",
    "timeout": "{recommended_timeout_hours:.0f}h",
    "secrets": {{"HF_TOKEN": "$HF_TOKEN"}}
}})
""")

if __name__ == "__main__":
    main()