def inference() -> None:
    """Do inference prediction on the test split and print evaluation metrics.

    Builds a ``NewsDataset`` from ``test.csv`` under ``Cfg.preprocessed_data_path``,
    wraps it in an unshuffled ``DataLoader``, restores ``CustomModel`` weights
    from ``Cfg.model_path``, moves the model to ``device``, runs ``test_step``
    and prints weighted precision, recall, F1 and plain accuracy.

    Raises:
        Exception: any error raised while loading the data or the model is
            logged with its traceback and re-raised unchanged. (Previously it
            was only logged, and the function then crashed later with a
            misleading ``UnboundLocalError`` on ``test_loader``/``model``.)
    """
    logger.info("Loading inference data.")
    try:
        test_dataset = NewsDataset(os.path.join(Cfg.preprocessed_data_path, "test.csv"))
        test_loader = DataLoader(
            test_dataset,
            batch_size=Cfg.batch_size,
            shuffle=False,  # keep a deterministic order for evaluation
            num_workers=4,
            pin_memory=True,
            drop_last=False,  # evaluate every test sample, including a final partial batch
        )
    except Exception as e:
        # Without the data there is nothing to evaluate; surface the real cause.
        logger.exception(e)
        raise

    logger.info("loading model.")
    try:
        model = CustomModel(num_classes=Cfg.num_classes)
        # Deserialize onto the CPU first so a checkpoint saved on a GPU can be
        # restored on any host, then move the model to the target device.
        model.load_state_dict(torch.load(Cfg.model_path, map_location=torch.device("cpu")))
        model.to(device)
    except Exception as e:
        # Evaluating an unloaded/partially loaded model is meaningless; propagate.
        logger.exception(e)
        raise

    # NOTE(review): this function never calls model.eval(); presumably
    # test_step switches to eval mode / disables grad — confirm.
    y_true, y_pred = test_step(test_loader, model)

    precision = precision_score(y_true, y_pred, average="weighted")
    recall = recall_score(y_true, y_pred, average="weighted")
    f1 = f1_score(y_true, y_pred, average="weighted")
    accuracy = accuracy_score(y_true, y_pred)
    print(f"Precision: {precision} \n Recall: {recall} \n F1: {f1} \n Accuracy: {accuracy}")