tune_loop(config=None)

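Run a single Weights & Biases tuning run: load and split the dataset, build the model, then train and validate for `config.epochs` epochs while logging the train and validation loss of each epoch.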

newsclassifier\tune.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def tune_loop(config=None):
    """Run one hyperparameter-tuning trial tracked by Weights & Biases.

    Opens a W&B run in the "NewsClassifier" project, reads the hyperparameters
    back from ``wandb.config``, builds the data loaders, model, loss,
    optimizer and LR scheduler, then trains and validates for
    ``config.epochs`` epochs, logging the losses of every epoch to W&B.

    An epoch that raises is logged (with traceback) and skipped; the loop
    then continues with the next epoch. Errors raised outside the epoch loop
    propagate to the caller.

    Args:
        config: Passed unchanged to ``wandb.init(config=...)``. Inside the run
            the following attributes are read from ``wandb.config``:
            ``batch_size``, ``dropout_pb``, ``learning_rate``,
            ``lr_reduce_factor``, ``lr_reduce_patience`` and ``epochs``.

    Returns:
        None.
    """
    # ====================================================
    # loader
    # ====================================================
    logger.info("Starting Tuning.")
    with wandb.init(project="NewsClassifier", config=config):
        # From here on the hyperparameters come from the active W&B run.
        config = wandb.config

        try:
            df = load_dataset(Cfg.dataset_loc)
            # Only the processed dataset is needed here; the other three
            # values returned by preprocess() are unused in this function.
            ds, _, _, _ = preprocess(df)
            train_ds, val_ds = data_split(ds, test_size=Cfg.test_size)

            train_dataset = NewsDataset(train_ds)
            valid_dataset = NewsDataset(val_ds)

            train_loader = DataLoader(
                train_dataset,
                batch_size=config.batch_size,
                shuffle=True,
                num_workers=4,
                pin_memory=True,
                drop_last=True,
            )
            valid_loader = DataLoader(
                valid_dataset,
                batch_size=config.batch_size,
                shuffle=False,
                num_workers=4,
                pin_memory=True,
                drop_last=False,
            )

            # ====================================================
            # model
            # ====================================================
            num_classes = Cfg.num_classes
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

            model = CustomModel(num_classes=num_classes, dropout_pb=config.dropout_pb)
            model.to(device)

            # ====================================================
            # Training components
            # ====================================================
            criterion = nn.BCEWithLogitsLoss()
            optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode="min", factor=config.lr_reduce_factor, patience=config.lr_reduce_patience
            )

            # ====================================================
            # loop
            # ====================================================
            wandb.watch(model, criterion, log="all", log_freq=10)

            for epoch in range(config.epochs):
                try:
                    start_time = time.time()

                    # Step
                    train_loss = train_step(train_loader, model, num_classes, criterion, optimizer, epoch)
                    val_loss, _, _ = eval_step(valid_loader, model, num_classes, criterion, epoch)
                    # The scheduler is driven by the validation loss (mode="min").
                    scheduler.step(val_loss)

                    # scoring
                    elapsed = time.time() - start_time
                    wandb.log({"epoch": epoch + 1, "train_loss": train_loss, "val_loss": val_loss})
                    print(f"Epoch {epoch+1} - avg_train_loss: {train_loss:.4f}  avg_val_loss: {val_loss:.4f}  time: {elapsed:.0f}s")
                except Exception as e:
                    # Best effort: a failed epoch must not abort the whole
                    # trial, but keep the traceback so the failure can be
                    # diagnosed.
                    logger.exception(f"Epoch {epoch+1}, {e}")
        finally:
            # Always release memory, even when setup or training failed.
            # Collect garbage first so memory held by unreachable objects is
            # freed before the CUDA cache is emptied.
            gc.collect()
            torch.cuda.empty_cache()