public abstract class WeightLearningApplication extends Object implements ModelApplication
Modifier and Type | Field and Description |
---|---|
protected List<Rule> | allRules |
protected PersistedAtomManager | atomManager: An atom manager on top of the rvDB. |
static String | CONFIG_PREFIX: Prefix of property keys used by this class. |
protected Evaluator | evaluator |
static String | EVALUATOR_DEFAULT |
static String | EVALUATOR_KEY: An evaluator capable of producing a score for the current weight configuration. |
protected double[] | expectedIncompatibility |
static String | GROUND_RULE_STORE_DEFAULT |
static String | GROUND_RULE_STORE_KEY: The class to use for ground rule storage. |
protected GroundRuleStore | groundRuleStore |
protected boolean | inLatentMPEState |
protected boolean | inMPEState: Flags to track whether the current variable configuration is an MPE state. |
protected GroundRuleStore | latentGroundRuleStore |
protected TermStore | latentTermStore |
static int | MAX_RANDOM_WEIGHT |
static int | MIN_ADMM_STEPS |
protected List<WeightedRule> | mutableRules |
protected Database | observedDB |
protected double[] | observedIncompatibility: Corresponds one-to-one with mutableRules. |
static boolean | RANDOM_WEIGHTS_DEFAULT |
static String | RANDOM_WEIGHTS_KEY: Randomize weights before running. |
protected Reasoner | reasoner |
static String | REASONER_DEFAULT |
static String | REASONER_KEY: The class to use for inference. |
protected Database | rvDB |
protected boolean | supportsLatentVariables |
static String | TERM_GENERATOR_DEFAULT |
static String | TERM_GENERATOR_KEY: The class to use for term generation. |
static String | TERM_STORE_DEFAULT |
static String | TERM_STORE_KEY: The class to use for term storage. |
protected TermGenerator | termGenerator |
protected TermStore | termStore |
protected TrainingMap | trainingMap |
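RANDOM_WEIGHTS_KEY toggles randomizing the weights before running, and MAX_RANDOM_WEIGHT suggests an upper bound for those random values. The following is a minimal, hypothetical sketch, assuming weights are drawn uniformly from [0, MAX_RANDOM_WEIGHT); the class and method names are illustrative, not part of the real API, and the bound is a parameter because this page does not give the constant's value.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical sketch: assigns each mutable rule a random weight in [0, maxRandomWeight).
// The real WeightLearningApplication may draw its random weights differently.
class RandomWeightsSketch {
    static double[] randomWeights(int ruleCount, int maxRandomWeight, long seed) {
        Random random = new Random(seed);
        double[] weights = new double[ruleCount];
        for (int i = 0; i < ruleCount; i++) {
            // nextDouble() is in [0, 1), so each weight lands in [0, maxRandomWeight).
            weights[i] = random.nextDouble() * maxRandomWeight;
        }
        return weights;
    }

    // Randomize only when the (hypothetical) flag read via RANDOM_WEIGHTS_KEY is true;
    // otherwise return a copy of the current weights unchanged.
    static double[] initialWeights(double[] current, boolean randomize, int maxRandomWeight, long seed) {
        if (!randomize) {
            return Arrays.copyOf(current, current.length);
        }
        return randomWeights(current.length, maxRandomWeight, seed);
    }
}
```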
Constructor and Description |
---|
WeightLearningApplication(List<Rule> rules, Database rvDB, Database observedDB, boolean supportsLatentVariables) |
Modifier and Type | Method and Description |
---|---|
void | close(): Releases all resources used by this ModelApplication. |
protected void | computeExpectedIncompatibility(): Compute the expected incompatibility in the model. |
protected void | computeLatentMPEState() |
double | computeLoss(): Internal method for computing the loss at the current point before taking a step. |
protected void | computeMPEState() |
protected void | computeObservedIncompatibility(): Compute the incompatibility in the model using the labels (truth values) from the observed (truth) database. |
protected PersistedAtomManager | createAtomManager(): Create an atom manager on top of the RV database. |
protected abstract void | doLearn(): Do the actual learning procedure. |
static WeightLearningApplication | getWLA(String className, List<Rule> rules, Database randomVariableDatabase, Database observedTruthDatabase): Construct a weight learning application given the data. |
protected void | initGroundModel(): Initialize all the infrastructure dealing with the ground model. |
void | initGroundModel(GroundRuleStore groundRuleStore): Initialize the ground model using an already populated ground rule store. |
void | initGroundModel(Reasoner reasoner, GroundRuleStore groundRuleStore, TermStore termStore, TermGenerator termGenerator, PersistedAtomManager atomManager, TrainingMap trainingMap): Pass in all the ground model infrastructure. |
protected void | initLatentGroundModel(): The same as initGroundModel, but for latent variables. |
void | learn(): Learns new weights. |
protected void | postInitGroundModel(): A convenient place for subclasses to do additional ground model initialization. |
void | setBudget(double budget): Set a budget (given as a proportion of the max budget). |
protected void | setDefaultRandomVariables(): Set RandomVariableAtoms with training labels to their default values. |
protected void | setLabeledRandomVariables(): Set RandomVariableAtoms with training labels to their observed values. |
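getWLA builds a weight learning application from a class name plus the rules and the two databases. One plausible way to implement such a factory is Java reflection: resolve the class by name, then invoke a constructor with matching parameter types. The sketch below assumes each subclass exposes a public (List, Database, Database) constructor; StubRule, StubDatabase, LearnerBase, DemoLearner, and LearnerFactory are hypothetical stand-ins so the example compiles on its own.

```java
import java.lang.reflect.Constructor;
import java.util.List;

// Hypothetical stand-ins for PSL's Rule and Database types.
class StubRule {}
class StubDatabase {}

abstract class LearnerBase {
    protected final List<StubRule> rules;
    protected final StubDatabase rvDB;
    protected final StubDatabase observedDB;

    protected LearnerBase(List<StubRule> rules, StubDatabase rvDB, StubDatabase observedDB) {
        this.rules = rules;
        this.rvDB = rvDB;
        this.observedDB = observedDB;
    }
}

// A concrete learner with the constructor shape the factory expects.
class DemoLearner extends LearnerBase {
    public DemoLearner(List<StubRule> rules, StubDatabase rvDB, StubDatabase observedDB) {
        super(rules, rvDB, observedDB);
    }
}

class LearnerFactory {
    // Resolves className and calls its public (List, StubDatabase, StubDatabase) constructor.
    static LearnerBase create(String className, List<StubRule> rules,
                              StubDatabase rvDB, StubDatabase observedDB) {
        try {
            Class<? extends LearnerBase> clazz =
                    Class.forName(className).asSubclass(LearnerBase.class);
            Constructor<? extends LearnerBase> ctor =
                    clazz.getConstructor(List.class, StubDatabase.class, StubDatabase.class);
            return ctor.newInstance(rules, rvDB, observedDB);
        } catch (ReflectiveOperationException | ClassCastException e) {
            // Unknown class, wrong supertype, missing constructor, or constructor failure.
            throw new IllegalArgumentException("Could not construct learner: " + className, e);
        }
    }
}
```

Wrapping the checked reflection exceptions in IllegalArgumentException keeps the factory signature free of checked exceptions; the real getWLA may report failures differently.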
|
public static final String CONFIG_PREFIX
public static final String REASONER_KEY
public static final String REASONER_DEFAULT
public static final String GROUND_RULE_STORE_KEY
public static final String GROUND_RULE_STORE_DEFAULT
public static final String TERM_STORE_KEY
public static final String TERM_STORE_DEFAULT
public static final String TERM_GENERATOR_KEY
public static final String TERM_GENERATOR_DEFAULT
public static final String EVALUATOR_KEY
public static final String EVALUATOR_DEFAULT
public static final String RANDOM_WEIGHTS_KEY
public static final boolean RANDOM_WEIGHTS_DEFAULT
public static final int MAX_RANDOM_WEIGHT
public static final int MIN_ADMM_STEPS
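Each pluggable component (reasoner, ground rule store, term store, term generator, evaluator) and the random-weights option is configured through a KEY/DEFAULT pair under CONFIG_PREFIX. A minimal sketch of that lookup pattern follows; the string values of the constants and the PrefixedConfigSketch class are placeholders, since this page does not show the real key strings or the configuration class in use.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch of a prefixed key/default lookup. All constant values are placeholders.
class PrefixedConfigSketch {
    static final String CONFIG_PREFIX = "weightlearning";                  // placeholder
    static final String REASONER_KEY = CONFIG_PREFIX + ".reasoner";        // placeholder
    static final String REASONER_DEFAULT = "ExampleReasoner";              // placeholder
    static final String RANDOM_WEIGHTS_KEY = CONFIG_PREFIX + ".randomweights"; // placeholder
    static final boolean RANDOM_WEIGHTS_DEFAULT = false;                   // placeholder

    private final Map<String, String> properties = new HashMap<>();

    void set(String key, String value) {
        properties.put(key, value);
    }

    // Returns the configured value, or the supplied default when the key is absent.
    String get(String key, String defaultValue) {
        return properties.getOrDefault(key, defaultValue);
    }

    // Parses a boolean property, falling back to the default when the key is absent.
    boolean getBoolean(String key, boolean defaultValue) {
        String value = properties.get(key);
        return value == null ? defaultValue : Boolean.parseBoolean(value);
    }
}
```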
protected boolean supportsLatentVariables
protected Database rvDB
protected Database observedDB
protected PersistedAtomManager atomManager
protected List<WeightedRule> mutableRules
protected double[] observedIncompatibility
protected double[] expectedIncompatibility
protected TrainingMap trainingMap
protected Reasoner reasoner
protected GroundRuleStore groundRuleStore
protected GroundRuleStore latentGroundRuleStore
protected TermGenerator termGenerator
protected TermStore termStore
protected TermStore latentTermStore
protected Evaluator evaluator
protected boolean inMPEState
protected boolean inLatentMPEState
public void learn()
The RandomVariableAtoms in the distribution are those persisted in the random variable Database when this method is called. All RandomVariableAtoms which the Model might access must be persisted in the Database. Each such RandomVariableAtom should have a corresponding ObservedAtom in the observed Database, unless the subclass implementation supports latent variables.
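learn() is the public entry point, doLearn() is the abstract step each subclass supplies, and postInitGroundModel() is an optional hook for extra ground model setup. The sketch below illustrates that template-method split with hypothetical classes (LearnerTemplate, LoggingLearner); the exact ordering of steps inside the real learn(), including skipping initialization when the infrastructure is already in place, is an assumption.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of the learn()/doLearn() template-method split.
abstract class LearnerTemplate {
    protected final List<String> log = new ArrayList<>();
    private boolean groundModelReady = false;

    // Public entry point: prepares shared infrastructure once, then delegates.
    public void learn() {
        if (!groundModelReady) {
            initGroundModel();
        }
        doLearn();
    }

    protected void initGroundModel() {
        log.add("initGroundModel");
        groundModelReady = true;
        postInitGroundModel();
    }

    // Optional hook for subclasses; does nothing by default.
    protected void postInitGroundModel() {}

    // Each concrete learner supplies its own optimization procedure.
    protected abstract void doLearn();
}

class LoggingLearner extends LearnerTemplate {
    @Override
    protected void postInitGroundModel() {
        log.add("postInitGroundModel");
    }

    @Override
    protected void doLearn() {
        log.add("doLearn");
    }
}
```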
protected abstract void doLearn()
public void setBudget(double budget)
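setBudget takes a proportion of some maximum budget. As a hypothetical illustration, assuming the budget scales a maximum iteration count and must lie in [0, 1], the proportion could be applied as below; maxIterations and BudgetSketch are illustrative names, not members of the real class.

```java
// Hypothetical sketch: a budget in [0, 1] scales a maximum iteration count.
class BudgetSketch {
    private final int maxIterations;
    private int iterations;

    BudgetSketch(int maxIterations) {
        this.maxIterations = maxIterations;
        this.iterations = maxIterations; // full budget until told otherwise
    }

    void setBudget(double budget) {
        if (budget < 0.0 || budget > 1.0) {
            throw new IllegalArgumentException("budget must be in [0, 1]: " + budget);
        }
        // Round up so that any positive budget allows at least one iteration.
        iterations = (int) Math.ceil(budget * maxIterations);
    }

    int getIterations() {
        return iterations;
    }
}
```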
protected void initGroundModel()
public void initGroundModel(GroundRuleStore groundRuleStore)
public void initGroundModel(Reasoner reasoner, GroundRuleStore groundRuleStore, TermStore termStore, TermGenerator termGenerator, PersistedAtomManager atomManager, TrainingMap trainingMap)
protected void postInitGroundModel()
protected void initLatentGroundModel()
protected void computeMPEState()
protected void computeLatentMPEState()
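The inMPEState and inLatentMPEState flags track whether the current variable configuration is already an MPE state. A minimal sketch, assuming such a flag is used to skip inference when nothing has changed and is cleared whenever a weight changes; MpeCacheSketch and runInference() are hypothetical, and the "inference" here only counts calls.

```java
// Hypothetical sketch of how a flag like inMPEState can skip redundant inference.
class MpeCacheSketch {
    private final double[] weights;
    private boolean inMPEState = false;
    private int inferenceRuns = 0;

    MpeCacheSketch(int ruleCount) {
        weights = new double[ruleCount];
    }

    void computeMPEState() {
        if (inMPEState) {
            return; // variables already hold an MPE state for the current weights
        }
        runInference();
        inMPEState = true;
    }

    // Changing any weight invalidates the cached MPE state.
    void setWeight(int ruleIndex, double weight) {
        weights[ruleIndex] = weight;
        inMPEState = false;
    }

    // Stand-in for a reasoner call; only counts invocations.
    private void runInference() {
        inferenceRuns++;
    }

    int getInferenceRuns() {
        return inferenceRuns;
    }
}
```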
protected void computeObservedIncompatibility()
protected void computeExpectedIncompatibility()
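observedIncompatibility and expectedIncompatibility correspond one-to-one with mutableRules, which is the shape needed for a gradient step on the rule weights. For a model whose density is proportional to exp(-sum_i w_i * phi_i), the log-likelihood gradient for rule i is its expected incompatibility minus its observed incompatibility. The sketch below applies one such perceptron-style step; it illustrates that standard update rather than what any particular subclass does, GradientStepSketch is a hypothetical name, and clipping at zero assumes weights must stay non-negative.

```java
// Hypothetical sketch of a perceptron-style weight update.
// observed[i]: incompatibility of mutable rule i under the training labels.
// expected[i]: incompatibility of rule i under the model's current prediction (e.g. MPE state).
class GradientStepSketch {
    static double[] step(double[] weights, double[] observed, double[] expected, double stepSize) {
        if (weights.length != observed.length || weights.length != expected.length) {
            throw new IllegalArgumentException("arrays must correspond 1-1 with the mutable rules");
        }
        double[] updated = new double[weights.length];
        for (int i = 0; i < weights.length; i++) {
            // Gradient ascent on the log-likelihood: d/dw_i = expected_i - observed_i.
            // If the prediction violates rule i more than the labels do, its weight grows.
            double gradient = expected[i] - observed[i];
            // Clip at zero, assuming rule weights must stay non-negative.
            updated[i] = Math.max(0.0, weights[i] + stepSize * gradient);
        }
        return updated;
    }
}
```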
public double computeLoss()
public void close()
Releases all resources used by this ModelApplication.
Specified by: close in interface ModelApplication
protected void setLabeledRandomVariables()
protected void setDefaultRandomVariables()
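setLabeledRandomVariables and setDefaultRandomVariables switch the random-variable atoms that have training labels between their observed values and their default values. The sketch below uses a hypothetical map from atoms to labels in place of TrainingMap; AtomSketch and TrainingMapSketch are illustrative names, and the default value is a constructor parameter because this page does not say what default the real class uses.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical minimal mutable atom; identity-based equality makes it a safe map key.
class AtomSketch {
    final String name;
    double value;

    AtomSketch(String name, double value) {
        this.name = name;
        this.value = value;
    }
}

class TrainingMapSketch {
    // Maps each labeled random-variable atom to its observed truth value.
    private final Map<AtomSketch, Double> labels = new LinkedHashMap<>();
    private final double defaultValue;

    TrainingMapSketch(double defaultValue) {
        this.defaultValue = defaultValue;
    }

    void addLabel(AtomSketch atom, double observedValue) {
        labels.put(atom, observedValue);
    }

    // Mirrors setLabeledRandomVariables(): copy each observed label onto its atom.
    void setLabeledRandomVariables() {
        for (Map.Entry<AtomSketch, Double> entry : labels.entrySet()) {
            entry.getKey().value = entry.getValue();
        }
    }

    // Mirrors setDefaultRandomVariables(): reset every labeled atom to the default value.
    void setDefaultRandomVariables() {
        for (AtomSketch atom : labels.keySet()) {
            atom.value = defaultValue;
        }
    }
}
```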
protected PersistedAtomManager createAtomManager()
Copyright © 2018 University of California, Santa Cruz. All rights reserved.