Reexpress MCP Server

utils_pretraining_initialization.py (10.7 kB)
# Copyright Reexpress AI, Inc. All rights reserved.

import constants
import utils_model
import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np


def load_pretraining_initialization_tensors(pretraining_initialization_tensors_file):
    concatenated_embeddings_and_labels = torch.load(
        pretraining_initialization_tensors_file, weights_only=True, map_location=torch.device("cpu"))
    # concatenated_embeddings_and_labels is shape: [total_instances BY (embedding dimension + labels)]
    train_embeddings = concatenated_embeddings_and_labels[:, 0:-1]
    train_labels = concatenated_embeddings_and_labels[:, -1].long()
    return train_embeddings, train_labels


def pretrain(options, model=None, model_dir=None, held_out_embeddings=None, held_out_labels=None,
             train_embeddings=None, train_labels=None, pretraining_learning_rate=None,
             return_min_held_out_balanced_loss=False, main_device=None, use_main_device=False):
    device_label = "main_device"
    if use_main_device:
        assert main_device is not None
        current_device = main_device
    else:
        device_label = "aux_device"
        current_device = torch.device(options.aux_device)
    if options.is_baseline_adaptor:
        total_epochs = options.epoch
        print(f"Training baseline CNN adaptor for {total_epochs} epochs on {current_device} ({device_label})")
    else:
        total_epochs = options.pretraining_initialization_epochs
        print(f"Pretraining initialization for {total_epochs} epochs on {current_device} ({device_label})")
    assert model is not None
    model = model.to(current_device)
    if train_embeddings is None:
        train_embeddings, train_labels = \
            load_pretraining_initialization_tensors(options.pretraining_initialization_tensors_file)
    train_size = train_embeddings.shape[0]
    if pretraining_learning_rate is None:
        pretraining_learning_rate = options.pretraining_learning_rate
    print(f"Starting pretraining over {train_size} instances with LR={pretraining_learning_rate}")
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adam(parameters, lr=pretraining_learning_rate, betas=(0.9, 0.999), eps=1e-08)
    criterion = nn.NLLLoss()

    all_epoch_cumulative_losses = []
    min_held_out_balanced_loss = np.inf
    min_held_out_balanced_loss_epoch = -1
    best_model_conv_weight = None
    best_model_conv_bias = None
    best_model_fc_weight = None
    best_model_fc_bias = None
    batch_size = options.batch_size
    default_training_q_values = torch.zeros(train_embeddings.shape[0], 1) + (np.e - model.q_rescale_offset)
    for e in range(total_epochs):
        # shuffle data
        shuffled_train_indexes = torch.randperm(train_embeddings.shape[0])
        shuffled_train_embeddings = train_embeddings[shuffled_train_indexes]
        shuffled_train_labels = train_labels[shuffled_train_indexes]
        shuffled_q = default_training_q_values[shuffled_train_indexes]

        batch_num = 0
        cumulative_losses = []
        for i in range(0, train_size, batch_size):
            batch_num += 1
            batch_range = min(batch_size, train_size - i)
            batch_x = shuffled_train_embeddings[i:i + batch_range].to(current_device)
            batch_y = shuffled_train_labels[i:i + batch_range].to(current_device)
            batch_q = shuffled_q[i:i + batch_range].to(current_device)
            batch_distance_quantile_per_class = None
            optimizer.zero_grad()
            model.train()
            _, rescaled_batch_output = model(batch_x, batch_q,
                                             batch_distance_quantile_per_class=batch_distance_quantile_per_class,
                                             forward_type=constants.FORWARD_TYPE_SENTENCE_LEVEL_PREDICTION,
                                             train=True)
            if len(rescaled_batch_output.shape) == 1:
                loss = criterion(rescaled_batch_output.unsqueeze(0), batch_y)
            else:
                loss = criterion(rescaled_batch_output, batch_y)
            cumulative_losses.append(loss.item())
            loss.backward()
            optimizer.step()
        print(f"---------------Pretraining Epoch: {e + 1}---------------")
        print(f"Pretraining Epoch average loss (over-pretraining set): {np.mean(cumulative_losses)}")
        all_epoch_cumulative_losses.extend(cumulative_losses)
        print(f"Pretraining Average loss across all mini-batches (all epochs, over-pretraining set): "
              f"{np.mean(all_epoch_cumulative_losses)}")
        held_out_loss_by_class_list, held_out_balanced_loss = \
            get_loss_over_heldout_data(options, model, held_out_embeddings, held_out_labels, current_device)
        print(f"Pretraining Epoch: {e + 1} / Held-out set Balanced loss: {held_out_balanced_loss}")
        print(f"Pretraining Epoch: {e + 1} / Held-out set Balanced loss by class:")
        for class_label in range(model.numberOfClasses):
            print(f"\tClass {class_label}: {held_out_loss_by_class_list[class_label]}")
        is_best_running_epoch = held_out_balanced_loss <= min_held_out_balanced_loss
        if held_out_balanced_loss <= min_held_out_balanced_loss:
            min_held_out_balanced_loss = held_out_balanced_loss
            min_held_out_balanced_loss_epoch = e + 1
        if is_best_running_epoch and total_epochs > 1:
            # Here, we are only updating the adaptor layer. The summary statistics and other data structures
            # have not (necessarily) yet been created. This pretraining is overall typically very fast,
            # so we simply cache the weights to memory.
            best_model_conv_weight = model.conv.weight.detach().clone()
            best_model_conv_bias = model.conv.bias.detach().clone()
            best_model_fc_weight = model.fc.weight.detach().clone()
            best_model_fc_bias = model.fc.bias.detach().clone()
            print(f">Cached best epoch parameters<")
        print(f"\tPretraining: Current min held-out set Balanced loss: "
              f"{min_held_out_balanced_loss} at epoch {min_held_out_balanced_loss_epoch}")
    print(
        f"\tPretraining: Min held-out set Balanced loss: "
        f"{min_held_out_balanced_loss} at epoch {min_held_out_balanced_loss_epoch}")
    if total_epochs > 1:
        # Add best weights
        model.conv.weight = nn.Parameter(best_model_conv_weight)
        model.conv.bias = nn.Parameter(best_model_conv_bias)
        model.fc.weight = nn.Parameter(best_model_fc_weight)
        model.fc.bias = nn.Parameter(best_model_fc_bias)
    if return_min_held_out_balanced_loss:
        return model, min_held_out_balanced_loss, min_held_out_balanced_loss_epoch
    else:
        return model


def get_loss_over_heldout_data(options, model, held_out_embeddings, held_out_labels, current_device):
    transfer_to_cpu = current_device == torch.device('mps')
    if transfer_to_cpu:
        # The scatter operations are not currently implemented on mps, so we need to move to cpu for the time being:
        original_current_device = current_device
        current_device = torch.device('cpu')
        model = model.to(torch.device('cpu'))
    criterion = nn.NLLLoss(reduction="none")
    batch_size = options.batch_size
    default_training_q_values = torch.zeros(held_out_embeddings.shape[0], 1) + (np.e - model.q_rescale_offset)
    held_out_size = held_out_embeddings.shape[0]
    batch_num = 0
    model.eval()
    class_size = model.numberOfClasses
    running_class_counts = torch.zeros(class_size).to(current_device)
    running_loss_sum_by_class = torch.zeros(class_size).to(current_device)
    with torch.no_grad():
        for i in range(0, held_out_size, batch_size):
            batch_num += 1
            batch_range = min(batch_size, held_out_size - i)
            batch_x = held_out_embeddings[i:i + batch_range].to(current_device)
            batch_y = held_out_labels[i:i + batch_range].to(current_device)
            batch_q = default_training_q_values[i:i + batch_range].to(current_device)
            batch_distance_quantile_per_class = None
            _, rescaled_batch_output = model(batch_x, batch_q,
                                             batch_distance_quantile_per_class=batch_distance_quantile_per_class,
                                             forward_type=constants.FORWARD_TYPE_SENTENCE_LEVEL_PREDICTION,
                                             train=True)
            if len(rescaled_batch_output.shape) == 1:
                loss = criterion(rescaled_batch_output.unsqueeze(0), batch_y)
            else:
                loss = criterion(rescaled_batch_output, batch_y)
            running_loss_sum_by_class += torch.zeros(class_size, device=current_device).scatter_reduce_(
                0, batch_y, loss, reduce='sum')
            running_class_counts += torch.zeros(class_size, device=current_device).scatter_add_(
                0, batch_y, torch.ones_like(batch_y, device=current_device, dtype=torch.float))
    per_class_avg = torch.where(running_class_counts > 0,
                                running_loss_sum_by_class / running_class_counts,
                                torch.zeros_like(running_loss_sum_by_class, device=current_device))
    balanced_loss = per_class_avg.mean()
    if transfer_to_cpu:
        model = model.to(original_current_device)
    return [float(x) for x in per_class_avg.detach().cpu().numpy().tolist()], balanced_loss.item()
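
For context, load_pretraining_initialization_tensors() expects a single saved tensor of shape [total_instances BY (embedding dimension + 1)], with the label appended as the final column. The following is a minimal sketch of producing a compatible file and round-tripping it through the loader above; the sizes and file name are illustrative, not from the repo.

import torch

num_instances, embedding_dim, num_classes = 128, 768, 2
embeddings = torch.randn(num_instances, embedding_dim)               # placeholder embeddings
labels = torch.randint(0, num_classes, (num_instances, 1)).float()   # labels as the final column
torch.save(torch.cat([embeddings, labels], dim=1), "pretraining_init.pt")

# Round trip through the loader defined above:
train_embeddings, train_labels = load_pretraining_initialization_tensors("pretraining_init.pt")
assert train_embeddings.shape == (num_instances, embedding_dim)
assert train_labels.dtype == torch.int64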
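
The held-out "balanced" loss computed by get_loss_over_heldout_data() is a macro-average: per-example losses are binned by gold class with scatter_reduce_, per-class counts are accumulated with scatter_add_, and the per-class means are then averaged, so each class contributes equally regardless of its frequency. A standalone illustration with made-up values:

import torch

class_size = 3
losses = torch.tensor([0.2, 0.9, 0.4, 0.1])  # per-example NLL losses (reduction="none")
labels = torch.tensor([0, 2, 0, 1])          # gold class of each example

loss_sum_by_class = torch.zeros(class_size).scatter_reduce_(0, labels, losses, reduce="sum")
class_counts = torch.zeros(class_size).scatter_add_(0, labels, torch.ones_like(losses))
per_class_avg = torch.where(class_counts > 0,
                            loss_sum_by_class / class_counts,
                            torch.zeros_like(loss_sum_by_class))
balanced_loss = per_class_avg.mean()  # mean of [0.3, 0.1, 0.9] ~= 0.4333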
