Reexpress MCP Server

by ReexpressAI
utils_train_main.py (24.7 kB)
# Copyright Reexpress AI, Inc. All rights reserved.

from sdm_model import SimilarityDistanceMagnitudeCalibrator
import constants
import utils_model
import utils_pretraining_initialization
import utils_calibrate
import utils_eval_batch

import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np

import logging
import sys
import time

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))


def print_timing_to_standard_out(label, elapsed_time, print_timing=False):
    if print_timing and elapsed_time is not None:
        print(f"[TIMING] (seconds): {label}: {time.time() - elapsed_time}")


def train(options, train_embeddings=None, calibration_embeddings=None, train_labels=None, calibration_labels=None,
          model_params=None, main_device=None, model_dir=None, model=None, shuffle_index=0):
    print(f"Training on main device: {main_device}")
    current_device = main_device
    if model is None:
        print("Initializing model")
        model = SimilarityDistanceMagnitudeCalibrator(**model_params).to(current_device)
    else:
        model = model.to(current_device)
    if options.pretraining_initialization_epochs > 0:
        # The train and calibration data for the iterative shuffle and SDM loss
        # (i.e., train_embeddings and calibration_embeddings) are assumed to be disjoint from
        # options.pretraining_initialization_tensors_file
        held_out_embeddings = torch.cat([train_embeddings, calibration_embeddings], dim=0)  # 0 is batch dim
        held_out_labels = torch.cat([train_labels, calibration_labels], dim=0)
        model = utils_pretraining_initialization.pretrain(options, model=model, model_dir=model_dir,
                                                          held_out_embeddings=held_out_embeddings,
                                                          held_out_labels=held_out_labels,
                                                          ).to(current_device)
        del held_out_embeddings
        del held_out_labels
    if options.ood_support_file.strip() != "":
        import utils_preprocess
        import data_validator
        ood_support_meta_data, _ = \
            utils_preprocess.get_metadata_lines(options, options.ood_support_file, reduce=False,
                                                use_embeddings=options.use_embeddings,
                                                concat_embeddings_to_attributes=options.concat_embeddings_to_attributes,
                                                calculate_summary_stats=False,
                                                is_training=False)
        ood_support_embeddings = ood_support_meta_data["embeddings"]
        ood_support_labels = ood_support_meta_data["labels"]
        label_parity_warning = False
        for ood_label in ood_support_labels:
            if ood_label != data_validator.oodLabel:
                label_parity_warning = True
        if label_parity_warning:
            print(f">>NOTE: Using --ood_support_file is primarily intended for adding OOD "
                  f"(label=={data_validator.oodLabel}) instances to the training database. "
                  f"You can add other instances to the support (as you are doing) using this "
                  f"mechanism, but note that they will not participate in the iterative shuffling with the "
                  f"calibration set. As such, typically documents with labels in [0, C) should instead be added to "
                  f"--input_training_set_file or --input_calibration_set_file.<<")
        ood_support_document_ids = ood_support_meta_data["uuids"]
        print(f"Loaded {len(ood_support_labels)} OOD/additional documents to add to the training support set")
        ood_support_labels = torch.tensor(ood_support_labels)
    else:
        ood_support_meta_data = None
    train_size = train_embeddings.shape[0]
    print("Starting training")
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adam(parameters, lr=options.learning_rate, betas=(0.9, 0.999), eps=1e-08)
    criterion = nn.NLLLoss()
    max_dev_acc = 0
    max_dev_acc_epoch = -1
    train_acc_for_max_dev_acc = 0
    max_dev_balanced_acc = 0
    max_dev_balanced_acc_epoch = -1
    train_balanced_acc_for_max_dev_acc = 0
    max_dev_balanced_q = 0
    max_dev_balanced_q_epoch = -1
    train_balanced_q_for_max_dev_balanced_q = 0
    min_dev_balanced_sdm_loss = np.inf
    min_dev_balanced_sdm_loss_epoch = -1
    train_balanced_sdm_loss_for_min_dev_sdm_loss = np.inf
    all_epoch_cumulative_losses = []
    batch_size = options.batch_size
    train_dataset_q_values = torch.zeros(train_embeddings.shape[0], 1) + (np.e - model.q_rescale_offset)
    train_dataset_distance_quantile_per_class = None
    for e in range(options.epoch):
        # shuffle data
        shuffled_train_indexes = torch.randperm(train_embeddings.shape[0])
        shuffled_train_embeddings = train_embeddings[shuffled_train_indexes]
        shuffled_train_labels = train_labels[shuffled_train_indexes]
        shuffled_q = train_dataset_q_values[shuffled_train_indexes]
        if e == 0:  # The initial epoch uses softmax (q=e-2, d=1).
            shuffled_distance_quantile_per_class = None
        else:
            shuffled_distance_quantile_per_class = train_dataset_distance_quantile_per_class[shuffled_train_indexes]
        batch_num = 0
        cumulative_losses = []
        single_epoch_time = time.time() if options.print_timing else None
        for i in range(0, train_size, batch_size):
            batch_num += 1
            batch_range = min(batch_size, train_size - i)
            batch_x = shuffled_train_embeddings[i:i + batch_range].to(current_device)
            batch_y = shuffled_train_labels[i:i + batch_range].to(current_device)
            batch_q = shuffled_q[i:i + batch_range].to(current_device)
            if shuffled_distance_quantile_per_class is not None:
                batch_distance_quantile_per_class = \
                    shuffled_distance_quantile_per_class[i:i + batch_range].to(current_device)
            else:
                batch_distance_quantile_per_class = None
            optimizer.zero_grad()
            model.train()
            _, batch_log_sdm_output = model(batch_x, batch_q,
                                            batch_distance_quantile_per_class=batch_distance_quantile_per_class,
                                            forward_type=constants.FORWARD_TYPE_SENTENCE_LEVEL_PREDICTION,
                                            train=True)
            if len(batch_log_sdm_output.shape) == 1:
                loss = criterion(batch_log_sdm_output.unsqueeze(0), batch_y)
            else:
                loss = criterion(batch_log_sdm_output, batch_y)
            cumulative_losses.append(loss.item())
            loss.backward()
            optimizer.step()
        print_timing_to_standard_out("Single epoch", single_epoch_time, print_timing=options.print_timing)
        print(f"---------------Shuffle Index {shuffle_index}: Epoch: {e + 1}---------------")
        print(f"Epoch average (marginal) cumulative loss: {np.mean(cumulative_losses)}")
        all_epoch_cumulative_losses.extend(cumulative_losses)
        print(f"Average (marginal) loss across all mini-batches (all epochs): {np.mean(all_epoch_cumulative_losses)}")
        support_stats_time = time.time() if options.print_timing else None
        # First get training set predictions and exemplar vectors.
        train_batch_f_positive_outputs, train_exemplar_vectors = \
            utils_eval_batch.get_predictions_and_exemplar_vectors(options.eval_batch_size, model, train_embeddings,
                                                                  current_device, place_output_on_cpu=True)
        # MUST set support set predictions before calculating q, d0:
        model.set_train_predicted_labels(torch.argmax(train_batch_f_positive_outputs, dim=1))
        calibration_batch_f_positive_outputs, calibration_exemplar_vectors = \
            utils_eval_batch.get_predictions_and_exemplar_vectors(options.eval_batch_size, model,
                                                                  calibration_embeddings,
                                                                  current_device, place_output_on_cpu=True)
        model.set_calibration_predicted_labels(torch.argmax(calibration_batch_f_positive_outputs, dim=1))
        if ood_support_meta_data is not None:
            ood_support_batch_f_positive_outputs, ood_support_exemplar_vectors = \
                utils_eval_batch.get_predictions_and_exemplar_vectors(
                    options.eval_batch_size, model, ood_support_embeddings,
                    current_device, place_output_on_cpu=True)
            ood_support_predicted_labels = torch.argmax(ood_support_batch_f_positive_outputs, dim=1)
            _, calibration_top_k_distances, calibration_top_k_distances_idx = \
                model.construct_support_index(support_exemplar_vectors_numpy=train_exemplar_vectors.numpy(),
                                              calibration_exemplar_vectors_numpy=calibration_exemplar_vectors.numpy(),
                                              ood_support_exemplar_vectors_numpy=ood_support_exemplar_vectors.numpy(),
                                              ood_support_labels=ood_support_labels,  # tensor
                                              ood_support_predicted_labels=ood_support_predicted_labels,  # tensor
                                              ood_support_document_ids=ood_support_document_ids
                                              )
        else:
            # Set the exemplar vectors of training as the support set and fetch the calibration distances
            _, calibration_top_k_distances, calibration_top_k_distances_idx = \
                model.construct_support_index(support_exemplar_vectors_numpy=train_exemplar_vectors.numpy(),
                                              calibration_exemplar_vectors_numpy=calibration_exemplar_vectors.numpy())
        # Fetch the training distances. This will include the identity match, which is handled below.
        # Currently, we assume there are no duplicates in the data splits (or at least there are very few)
        train_top_k_distances__including_self, train_top_k_distances_idx__including_self = \
            model.get_top_support_distances(train_exemplar_vectors.numpy())
        # get q values and dCDF for training; is_training_support=True will discard the first (identity) match
        # Note that the distance quantiles for training are determined by distances over training. The class
        # attribute model.trueClass_To_dCDF is over calibration, which is what should be used for new, unseen
        # test instances.
        train_dataset_q_values, train_trueClass_To_dataset_total_q_ood, train_trueClass_To_total_labels, \
            train_dataset_d0_values, train_trueClass_To_dCDF = model.set_summary_stats_for_support_vectorized(
                train_exemplar_vectors.shape[0],
                train_top_k_distances__including_self,
                train_top_k_distances_idx__including_self,
                train_batch_f_positive_outputs,
                train_labels,
                is_training_support=True)
        model.set_train_trueClass_To_dCDF(train_trueClass_To_dCDF)
        for class_i in range(model.numberOfClasses):
            if len(train_trueClass_To_dCDF[class_i]) > 0:
                print(f"\tDistances: {'Train'}: (class {class_i}) mean d0: {np.mean(train_trueClass_To_dCDF[class_i])}; "
                      f"median d0: {np.median(train_trueClass_To_dCDF[class_i])}, "
                      f"min: {np.min(train_trueClass_To_dCDF[class_i])}, "
                      f"max: {np.max(train_trueClass_To_dCDF[class_i])}, "
                      f"out of {len(train_trueClass_To_dCDF[class_i])}")
            else:
                print(
                    f"\tDistances: {'Train'}: (class {class_i}): WARNING NO DISTANCES AVAILABLE")
        for class_i in range(model.numberOfClasses):
            print(f"\tTotal OOD q values (q<={model.ood_limit}): {'Train'}: (class {class_i}): "
                  f"{train_trueClass_To_dataset_total_q_ood[class_i]} "
                  f"out of {train_trueClass_To_total_labels[class_i]}: "
                  f"{train_trueClass_To_dataset_total_q_ood[class_i]/(float(train_trueClass_To_total_labels[class_i]) if train_trueClass_To_total_labels[class_i] > 0 else 1.0)}")
        # get q values for calibration and set the class dCDF
        calibration_dataset_q_values, calibration_trueClass_To_dataset_total_q_ood, \
            calibration_trueClass_To_total_labels, calibration_dataset_d0_values, _ = \
            model.set_summary_stats_for_support_vectorized(calibration_exemplar_vectors.shape[0],
                                                           calibration_top_k_distances,
                                                           calibration_top_k_distances_idx,
                                                           calibration_batch_f_positive_outputs,
                                                           calibration_labels,
                                                           is_training_support=False)
        for class_i in range(model.numberOfClasses):
            if len(model.trueClass_To_dCDF[class_i]) > 0:
                print(f"\tDistances: {constants.SPLIT_LABEL_calibration_during_training}: (class {class_i}) mean d0: "
                      f"{np.mean(model.trueClass_To_dCDF[class_i])}; "
                      f"median d0: {np.median(model.trueClass_To_dCDF[class_i])}, "
                      f"min: {np.min(model.trueClass_To_dCDF[class_i])}, "
                      f"max: {np.max(model.trueClass_To_dCDF[class_i])}, "
                      f"out of {len(model.trueClass_To_dCDF[class_i])}")
            else:
                print(
                    f"\tDistances: {constants.SPLIT_LABEL_calibration_during_training}: (class {class_i}): "
                    f"WARNING NO DISTANCES AVAILABLE")
        for class_i in range(model.numberOfClasses):
            print(f"\tTotal OOD q values (q<={model.ood_limit}): {constants.SPLIT_LABEL_calibration_during_training}: (class {class_i}): "
                  f"{calibration_trueClass_To_dataset_total_q_ood[class_i]} "
                  f"out of {calibration_trueClass_To_total_labels[class_i]}: "
                  f"{calibration_trueClass_To_dataset_total_q_ood[class_i]/(float(calibration_trueClass_To_total_labels[class_i]) if calibration_trueClass_To_total_labels[class_i] > 0 else 1.0)}")
        print_timing_to_standard_out("Construct exemplars, q, distances", support_stats_time,
                                     print_timing=options.print_timing)
        train_d_time = time.time() if options.print_timing else None
        # get training distance quantiles, using distance empirical CDF over training
        train_dataset_distance_quantile_per_class = \
            model.get_distance_quantiles_vectorized(train_dataset_d0_values,
                                                    train_trueClass_To_dCDF=train_trueClass_To_dCDF)
        print_timing_to_standard_out("Calculate training distance quantiles", train_d_time,
                                     print_timing=options.print_timing)
        calibration_d_time = time.time() if options.print_timing else None
        # get calibration distance quantiles, using distance empirical CDF over calibration
        calibration_dataset_distance_quantile_per_class = \
            model.get_distance_quantiles_vectorized(calibration_dataset_d0_values, train_trueClass_To_dCDF=None)
        # utils_unit_tests.run_unit_test_comparison_of_get_distance_quantiles(model, calibration_dataset_d0_values,
        #                                                                     train_dataset_d0_values, train_trueClass_To_dCDF)
        print_timing_to_standard_out("Calculate calibration distance quantiles", calibration_d_time,
                                     print_timing=options.print_timing)
        eval_time = time.time() if options.print_timing else None
        # Calculate metrics from the cached output (the CNN is not rerun).
        # In the current version, we do not reset train_dataset_q_values for predictions flips resulting after
        # rescaling, since they are very rare and will be handled otherwise by low q and d. However, as with earlier
        # versions, the prediction is always determined by f rather than the sdm output.
        train_per_class_loss_as_list, train_balanced_loss, train_marginal_loss, \
            train_per_class_accuracy_as_list, train_balanced_accuracy, train_marginal_accuracy, \
            train_per_class_q_as_list, train_balanced_q, train_marginal_q, \
            train_sdm_outputs = \
            utils_eval_batch.get_metrics_from_cached_outputs(options.eval_batch_size, model,
                                                             train_batch_f_positive_outputs, current_device,
                                                             train_labels,
                                                             q_values=train_dataset_q_values,
                                                             distance_quantile_per_class=train_dataset_distance_quantile_per_class)
        calibration_per_class_loss_as_list, calibration_balanced_loss, calibration_marginal_loss, \
            calibration_per_class_accuracy_as_list, calibration_balanced_accuracy, calibration_marginal_accuracy, \
            calibration_per_class_q_as_list, calibration_balanced_q, calibration_marginal_q, \
            calibration_sdm_outputs = \
            utils_eval_batch.get_metrics_from_cached_outputs(options.eval_batch_size, model,
                                                             calibration_batch_f_positive_outputs, current_device,
                                                             calibration_labels,
                                                             q_values=calibration_dataset_q_values,
                                                             distance_quantile_per_class=calibration_dataset_distance_quantile_per_class)
        print_timing_to_standard_out("Calculate metrics", eval_time, print_timing=options.print_timing)
        time_to_set_similarity = time.time() if options.print_timing else None
        utils_calibrate.set_model_rescaled_similarity_vectorized(model, calibration_batch_f_positive_outputs,
                                                                 calibration_dataset_q_values,
                                                                 calibration_sdm_outputs)
        print_timing_to_standard_out("Set model rescaled Similarity", time_to_set_similarity,
                                     print_timing=options.print_timing)
        utils_eval_batch.print_metrics(e=e, numberOfClasses=model.numberOfClasses, split_label_name="Training set",
                                       per_class_loss_as_list=train_per_class_loss_as_list,
                                       balanced_loss=train_balanced_loss,
                                       marginal_loss=train_marginal_loss,
                                       per_class_accuracy_as_list=train_per_class_accuracy_as_list,
                                       balanced_accuracy=train_balanced_accuracy,
                                       marginal_accuracy=train_marginal_accuracy,
                                       per_class_q_as_list=train_per_class_q_as_list,
                                       balanced_q=train_balanced_q,
                                       marginal_q=train_marginal_q)
        utils_eval_batch.print_metrics(e=e, numberOfClasses=model.numberOfClasses, split_label_name="Calibration set",
                                       per_class_loss_as_list=calibration_per_class_loss_as_list,
                                       balanced_loss=calibration_balanced_loss,
                                       marginal_loss=calibration_marginal_loss,
                                       per_class_accuracy_as_list=calibration_per_class_accuracy_as_list,
                                       balanced_accuracy=calibration_balanced_accuracy,
                                       marginal_accuracy=calibration_marginal_accuracy,
                                       per_class_q_as_list=calibration_per_class_q_as_list,
                                       balanced_q=calibration_balanced_q,
                                       marginal_q=calibration_marginal_q)
        is_best_running_epoch = calibration_balanced_loss <= min_dev_balanced_sdm_loss
        if calibration_balanced_loss <= min_dev_balanced_sdm_loss:
            min_dev_balanced_sdm_loss = calibration_balanced_loss
            min_dev_balanced_sdm_loss_epoch = e + 1
            train_balanced_sdm_loss_for_min_dev_sdm_loss = train_balanced_loss
        if calibration_marginal_accuracy >= max_dev_acc:
            max_dev_acc = calibration_marginal_accuracy
            max_dev_acc_epoch = e + 1
            train_acc_for_max_dev_acc = train_marginal_accuracy
        if calibration_balanced_accuracy >= max_dev_balanced_acc:
            max_dev_balanced_acc = calibration_balanced_accuracy
            max_dev_balanced_acc_epoch = e + 1
            train_balanced_acc_for_max_dev_acc = train_balanced_accuracy
        if calibration_balanced_q >= max_dev_balanced_q:
            max_dev_balanced_q = calibration_balanced_q
            max_dev_balanced_q_epoch = e + 1
            train_balanced_q_for_max_dev_balanced_q = train_balanced_q
        if is_best_running_epoch:
            model.increment_model_calibration_training_stage(set_value=1)
            utils_model.save_model(model, model_dir)
            logger.info(f"Model saved at {model_dir} as best running epoch.")
        print(f"---Summary---")
        print(f"\tCurrent max Calibration set accuracy: {max_dev_acc} at epoch {max_dev_acc_epoch} "
              f"(corresponding Training set accuracy: {train_acc_for_max_dev_acc})")
        print(f"\tCurrent max Calibration set Balanced accuracy: {max_dev_balanced_acc} at epoch {max_dev_balanced_acc_epoch} "
              f"(corresponding Training set Balanced accuracy: {train_balanced_acc_for_max_dev_acc})")
        print(f"\tCurrent max Calibration set Balanced q: {max_dev_balanced_q} at epoch {max_dev_balanced_q_epoch} "
              f"(corresponding Training set Balanced q: {train_balanced_q_for_max_dev_balanced_q})")
        print(f"\tCurrent min Calibration set Balanced SDM loss: {min_dev_balanced_sdm_loss} at epoch {min_dev_balanced_sdm_loss_epoch} "
              f"(corresponding Training set Balanced SDM loss: {train_balanced_sdm_loss_for_min_dev_sdm_loss})")
    print(f"+++++++++++++++Shuffle Index {shuffle_index}: Training complete+++++++++++++++")
    print(f"\tMax Calibration set accuracy: {max_dev_acc} at epoch {max_dev_acc_epoch} "
          f"(corresponding Training set accuracy: {train_acc_for_max_dev_acc})")
    print(
        f"\tMax Calibration set Balanced accuracy: {max_dev_balanced_acc} at epoch {max_dev_balanced_acc_epoch} "
        f"(corresponding Training set Balanced accuracy: {train_balanced_acc_for_max_dev_acc})")
    print(
        f"\tMax Calibration set Balanced q: {max_dev_balanced_q} at epoch {max_dev_balanced_q_epoch} "
        f"(corresponding Training set Balanced q: {train_balanced_q_for_max_dev_balanced_q})")
    print(
        f"\tMin Calibration set Balanced SDM loss: {min_dev_balanced_sdm_loss} at epoch {min_dev_balanced_sdm_loss_epoch} "
        f"(corresponding Training set Balanced SDM loss: {train_balanced_sdm_loss_for_min_dev_sdm_loss})")
    print(f"Final epoch chosen based on the minimum Balanced SDM loss (over calibration).")
    print(f"Reloading best model to calibrate based on the provided alpha value.")
    min_rescaled_similarity_to_determine_high_reliability_region = \
        utils_calibrate.calibrate_to_determine_high_reliability_region(options, model_dir=model_dir)
    return max_dev_balanced_acc, max_dev_balanced_q, min_dev_balanced_sdm_loss, \
        min_rescaled_similarity_to_determine_high_reliability_region
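
For orientation, the sketch below shows one way the train() entry point above could be driven. It is a minimal, hypothetical example and is not part of utils_train_main.py: the options object is mocked with an argparse.Namespace carrying only the attributes that train() itself reads, the tensor shapes and hyperparameter values are placeholders, and model_params must be filled in with whatever keyword arguments SimilarityDistanceMagnitudeCalibrator.__init__ (defined in sdm_model.py, not shown here) actually expects. The helper modules called by train() (e.g., utils_calibrate for the alpha-based calibration step) read additional options attributes, so a real driver would supply those as well.

# Hypothetical driver for utils_train_main.train(); all values and paths are placeholders.
from argparse import Namespace
import torch
import utils_train_main

options = Namespace(
    pretraining_initialization_epochs=0,    # skip the optional pretraining-initialization branch
    ood_support_file="",                    # empty: no extra OOD documents added to the support set
    use_embeddings=True,                    # only read when ood_support_file is non-empty
    concat_embeddings_to_attributes=False,  # likewise only read for the OOD support branch
    learning_rate=1e-4,                     # placeholder hyperparameters
    batch_size=50,
    eval_batch_size=50,
    epoch=20,
    print_timing=False,
)

# Placeholder embeddings and labels; in practice these come from the upstream preprocessing utilities.
train_embeddings = torch.randn(1000, 768)
train_labels = torch.randint(0, 2, (1000,))
calibration_embeddings = torch.randn(1000, 768)
calibration_labels = torch.randint(0, 2, (1000,))

# Placeholder: must match the keyword arguments expected by SimilarityDistanceMagnitudeCalibrator.
sdm_model_params = {}

max_balanced_acc, max_balanced_q, min_balanced_sdm_loss, min_rescaled_similarity = \
    utils_train_main.train(options,
                           train_embeddings=train_embeddings,
                           calibration_embeddings=calibration_embeddings,
                           train_labels=train_labels,
                           calibration_labels=calibration_labels,
                           model_params=sdm_model_params,
                           main_device=torch.device("cpu"),
                           model_dir="output/model_dir",  # placeholder path
                           shuffle_index=0)

As the returned tuple suggests, the caller can compare runs (e.g., across shuffle indexes) by the calibration-set balanced accuracy, balanced q, and balanced SDM loss, with the saved best model selected inside train() by the minimum balanced SDM loss over calibration.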
