Agents
This module implements the fitted Q-iteration algorithm for offline reinforcement learning.
from pycfrl import agents
- class pycfrl.agents.agents.Agent
Bases: object
Base class for reinforcement learning agents. Subclasses must implement the act() method.
- __init__() None
- abstract act(z: list | ndarray, xt: list | ndarray, xtm1: list | ndarray | None = None, atm1: list | ndarray | None = None, uat: list | ndarray | None = None, **kwargs) ndarray
An abstract method for making decisions using the agent.
- Args:
- z (list or np.ndarray):
The observed sensitive attributes of each individual for whom the decisions are to be made. It should be a 2D list or array following the Sensitive Attributes Format.
- xt (list or np.ndarray):
The states at the current time step of each individual for whom the decisions are to be made. It should be a 2D list or array following the Single-time States Format.
- xtm1 (list or np.ndarray, optional):
The states at the previous time step of each individual for whom the decisions are to be made. It should be a 2D list or array following the Single-time States Format.
- atm1 (list or np.ndarray, optional):
The actions at the previous time step of each individual for whom the decisions are to be made. It should be a 1D list or array following the Single-time Actions Format. When both
xtm1 and atm1 are set to None, the agent will consider the input to be from the initial time step of a new trajectory, and the internal preprocessor will be reset if it is an instance of SequentialPreprocessor.
- uat (list or np.ndarray, optional):
The exogenous variables for each individual’s action. It should be a 2D list or array with shape (N, 1) where N is the total number of individuals.
- Returns:
- actions (np.ndarray):
The decisions made for the individuals. It is a 1D array following the Single-time Actions Format.
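Because act() is abstract, a custom agent only needs to subclass Agent and override act(). The sketch below is illustrative only: the constant policy it implements is not part of the library, and the import path is taken from the class listing above.

```python
import numpy as np
from pycfrl.agents.agents import Agent

class ConstantAgent(Agent):
    """Toy agent that always chooses action 0 for every individual (illustrative only)."""

    def act(self, z, xt, xtm1=None, atm1=None, uat=None, **kwargs):
        xt = np.asarray(xt)
        # One decision per individual, returned as a 1D array
        # (Single-time Actions Format).
        return np.zeros(xt.shape[0], dtype=int)
```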
- class pycfrl.agents.agents.FQI(num_actions: int, model_type: Literal['nn', 'lm'], hidden_dims: list[int] = [32], preprocessor: Preprocessor | None = None, gamma: int | float = 0.9, learning_rate: int | float = 0.1, epochs: int = 500, is_loss_monitored: bool = True, is_early_stopping_nn: bool = False, test_size_nn: int | float = 0.2, loss_monitoring_patience: int = 10, loss_monitoring_min_delta: int | float = 0.005, early_stopping_patience_nn: int = 10, early_stopping_min_delta_nn: int | float = 0.005, is_q_monitored: bool = True, is_early_stopping_q: bool = False, q_monitoring_patience: int = 5, q_monitoring_min_delta: int | float = 0.005, early_stopping_patience_q: int = 5, early_stopping_min_delta_q: int | float = 0.005)
Bases: Agent
Implementation of the fitted Q-iteration (FQI) algorithm.
FQI can be used to learn the optimal policy from offline data.
In particular, for an FQI object, users can specify whether to add a preprocessor internally. If a preprocessor is added internally, then the FQI object will preprocess the input data before using the data for training (the train() method) and decision-making (the act() method).
- References:
- __init__(num_actions: int, model_type: Literal['nn', 'lm'], hidden_dims: list[int] = [32], preprocessor: Preprocessor | None = None, gamma: int | float = 0.9, learning_rate: int | float = 0.1, epochs: int = 500, is_loss_monitored: bool = True, is_early_stopping_nn: bool = False, test_size_nn: int | float = 0.2, loss_monitoring_patience: int = 10, loss_monitoring_min_delta: int | float = 0.005, early_stopping_patience_nn: int = 10, early_stopping_min_delta_nn: int | float = 0.005, is_q_monitored: bool = True, is_early_stopping_q: bool = False, q_monitoring_patience: int = 5, q_monitoring_min_delta: int | float = 0.005, early_stopping_patience_q: int = 5, early_stopping_min_delta_q: int | float = 0.005) None
- Args:
- num_actions (int):
The total number of possible actions.
- model_type (str):
The type of the model used for learning the Q function. Can be "lm" (polynomial regression) or "nn" (neural network). Currently, only "nn" is supported.
- hidden_dims (list[int], optional):
The hidden dimensions of the neural network. This argument is not used if
model_type="lm".- preprocessor (Preprocessor, optional):
A preprocessor used for preprocessing input data before using the data for training or decision-making. The preprocessor must have already been trained if it requires training. When set to
None, FQI will directly use the input data for training or decision-making without preprocessing it.
- gamma (int or float, optional):
The discount factor for the cumulative discounted reward in the objective function.
- learning_rate (int or float, optional):
The learning rate of the neural network. This argument is not used if
model_type="lm".- epochs (int, optional):
The number of training epochs for the neural network. This argument is not used if
model_type="lm".- is_loss_monitored (bool, optional):
When set to
True, will split the training data into a training set and a validation set, and will monitor the validation loss when training the neural network approximator of the Q function in each iteration. A warning will be raised if the percent absolute change in the validation loss is greater than loss_monitoring_min_delta for at least one of the final \(p\) epochs during neural network training, where \(p\) is specified by the argument loss_monitoring_patience. This argument is not used if model_type="lm".
- is_early_stopping_nn (bool, optional):
When set to
True, will split the training data into a training set and a validation set, and will enforce early stopping based on the validation loss when training the neural network approximator of the Q function in each iteration. That is, in each iteration, neural network training will stop early if the percent decrease in the validation loss is no greater than early_stopping_min_delta_nn for \(q\) consecutive training epochs, where \(q\) is specified by the argument early_stopping_patience_nn. This argument is not used if model_type="lm".
- test_size_nn (int or float, optional):
An
int or float between 0 and 1 (inclusive) that specifies the proportion of the full training data that is used as the validation set for loss monitoring and early stopping. This argument is not used if model_type="lm" or both is_loss_monitored and is_early_stopping_nn are False.
- loss_monitoring_patience (int, optional):
The number of consecutive epochs with barely-changing validation loss at the end of neural network training that is needed for loss monitoring to not raise warnings. This argument is not used if
model_type="lm"oris_loss_monitored=False.- loss_monitoring_min_delta (int for float, optional):
The maximum amount of percent absolute change in the validation loss for it to be considered barely-changing by the loss monitoring mechanism. This argument is not used if
model_type="lm"oris_loss_monitored=False.- early_stopping_patience_nn (int, optional):
The number of consecutive epochs with barely-decreasing validation loss during neural network training that is needed for early stopping to be triggered. This argument is not used if
model_type="lm"oris_early_stopping_nn=False.- early_stopping_min_delta_nn (int for float, optional):
The maximum amount of percent decrease in the validation loss for it to be considered barely-decreasing by the early stopping mechanism. This argument is not used if
model_type="lm"oris_early_stopping_nn=False.- is_q_monitored (bool, optional):
When set to
True, will monitor the Q values estimated by the neural network approximator of the Q function in each iteration at all the state-action pairs present in the training trajectory. A warning will be raised if the percent absolute change in some Q value is greater than q_monitoring_min_delta for at least one of the final \(r\) iterations of model updates, where \(r\) is specified by the argument q_monitoring_patience. This argument is not used if model_type="lm".
- is_early_stopping_q (bool, optional):
When set to
True, will monitor the Q values estimated by the neural network approximator of the Q function at all the state-action pairs present in the training trajectory, and will enforce early stopping based on the estimated Q values when training the approximated Q function. That is, FQI training will stop early if the percent absolute changes in all the predicted Q values are no greater than early_stopping_min_delta_q for \(s\) consecutive iterations of model updates, where \(s\) is specified by the argument early_stopping_patience_q. This argument is not used if model_type="lm".
- q_monitoring_patience (int, optional):
The number of consecutive iterations with barely-changing estimated Q values at the end of the iterative updates that is needed for Q value monitoring to not raise warnings. This argument is not used if
model_type="lm"oris_q_monitored=False.- q_monitoring_min_delta (int for float, optional):
The maximum amount of percent absolute change in the estimated Q values for them to be considered barely-changing by the Q value monitoring mechanism. This argument is not used if
model_type="lm"oris_q_monitored=False.- early_stopping_patience_q (int, optional):
The number of consecutive iterations with barely-changing estimated Q values that is needed for early stopping to be triggered. This argument is not used if
model_type="lm"oris_early_stopping_q=False.- early_stopping_min_delta_q (int for float, optional):
The maximum amount of percent absolute change in the estimated Q values for them to be considered barely-changing by the early stopping mechanism. This argument is not used if
model_type="lm"oris_early_stopping_q=False.
- act(z: list | ndarray, xt: list | ndarray, xtm1: list | ndarray | None = None, atm1: list | ndarray | None = None, uat: list | ndarray | None = None, preprocess: bool = True) ndarray
Make decisions using the FQI agent.
Important Note when the internal preprocessor is a
SequentialPreprocessor: A SequentialPreprocessor object internally stores the preprocessed counterfactual states from the previous function call using a states buffer, and the stored counterfactual states will be used to preprocess the inputs of the current function call. In this case, suppose act() is called on a set of transitions at time \(t\) in some trajectory. Then, at the next call of act() for this instance of FQI, the transitions passed to the function must be from time \(t+1\) of the same trajectory to ensure that the buffer works correctly. To preprocess another trajectory, either use another instance of FQI, or pass the initial step of the trajectory to act() with xtm1=None and atm1=None to reset the buffer. Similar issues might also arise when the internal preprocessor is some custom preprocessor that relies on buffers.
- Args:
- z (list or np.ndarray):
The observed sensitive attributes of each individual for whom the decisions are to be made. It should be a 2D list or array following the Sensitive Attributes Format.
- xt (list or np.ndarray):
The states at the current time step of each individual for whom the decisions are to be made. It should be a 2D list or array following the Single-time States Format.
- xtm1 (list or np.ndarray, optional):
The states at the previous time step of each individual for whom the decisions are to be made. It should be a 2D list or array following the Single-time States Format.
- atm1 (list or np.ndarray, optional):
The actions at the previous time step of each individual for whom the decisions are to be made. It should be a 1D list or array following the Single-time Actions Format. When both
xtm1 and atm1 are set to None, the agent will consider the input to be from the initial time step of a new trajectory, and the internal preprocessor will be reset if it is an instance of SequentialPreprocessor.
- uat (list or np.ndarray, optional):
The exogenous variables for each individual’s action. It should be a 2D list or array with shape (N, 1) where N is the total number of individuals.
- Returns:
- actions (np.ndarray):
The decisions made for the individuals. It is a 1D array following the Single-time Actions Format.
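As a sketch of the buffer behavior described above, the example below calls act() on two consecutive time steps of the same trajectory and uses xtm1=None, atm1=None to mark the start of a new trajectory. It assumes an already-trained FQI agent named agent (for example, one constructed as in the earlier sketch); the data shapes (2 individuals, 1 sensitive attribute, 2 state features) are illustrative only.

```python
import numpy as np

z = np.array([[0.0], [1.0]])   # Sensitive Attributes Format (2D)
x0 = np.random.randn(2, 2)     # states at time 0 (Single-time States Format)

# Initial step of a new trajectory: xtm1=None and atm1=None reset the
# internal SequentialPreprocessor buffer, if one is used.
a0 = agent.act(z=z, xt=x0, xtm1=None, atm1=None)

# Next step of the SAME trajectory: pass the previous states and actions so
# the buffered counterfactual states stay aligned with this trajectory.
x1 = np.random.randn(2, 2)
a1 = agent.act(z=z, xt=x1, xtm1=x0, atm1=a0)
```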
- train(zs: list | ndarray, xs: list | ndarray, actions: list | ndarray, rewards: list | ndarray, max_iter: int = 1000, preprocess: bool = True) None
Train the FQI agent.
The observed sensitive attributes
zs are used only by the internal preprocessor; they are not directly used during policy learning.
- Args:
- zs (list or np.ndarray):
The observed sensitive attributes of each individual in the training data. It should be a list or array following the Sensitive Attributes Format.
- xs (list or np.ndarray):
The state trajectory used for training. It should be a list or array following the Full-trajectory States Format.
- actions (list or np.ndarray):
The action trajectory used for training. It should be a list or array following the Full-trajectory Actions Format.
- rewards (list or np.ndarray):
The reward trajectory used for training. It should be a list or array following the Full-trajectory Rewards Format.
- max_iter (int, optional):
The maximum number of iterations for learning the Q function.
- preprocess (bool, optional):
Whether to preprocess the training data before training. When set to
False, the training data will not be preprocessed even if preprocessor is not None in the constructor.
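A minimal end-to-end sketch of training and then querying an FQI agent on synthetic offline data is shown below. The trajectory shapes (N = 100 individuals, T = 10 transitions, 2 state features) are assumptions about the Full-trajectory formats made only for illustration; consult the data format documentation for the exact layouts.

```python
import numpy as np
from pycfrl.agents.agents import FQI

N, T, state_dim = 100, 10, 2
zs = np.random.binomial(1, 0.5, size=(N, 1)).astype(float)  # Sensitive Attributes Format
xs = np.random.randn(N, T + 1, state_dim)                    # assumed Full-trajectory States Format
actions = np.random.randint(0, 2, size=(N, T))               # assumed Full-trajectory Actions Format
rewards = np.random.randn(N, T)                              # assumed Full-trajectory Rewards Format

agent = FQI(num_actions=2, model_type="nn", hidden_dims=[32], gamma=0.9)
agent.train(zs=zs, xs=xs, actions=actions, rewards=rewards, max_iter=100)

# Decisions for the initial step of a new trajectory.
a0 = agent.act(z=zs, xt=xs[:, 0, :], xtm1=None, atm1=None)
```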