/**
* ProbLog Code Generator for PDSL
*
* Translates validated PDSL AST into ProbLog code.
*/
import {
Program,
Model,
Statement,
ProbabilisticFact,
ProbabilisticRule,
DeterministicFact,
AnnotatedDisjunction,
Observation,
Query,
LearningDirective,
Atom,
Literal,
Term,
} from './probabilisticAST.js';
// ============================================================================
// Generator Class
// ============================================================================
export class ProbLogGenerator {
  private output: string[] = [];
  private indentLevel = 0;

  /**
   * Generate ProbLog code from AST.
   *
   * Resets any state left by a previous call, emits a fixed header comment,
   * then translates each model in `ast.models` in order.
   *
   * @param ast - Program whose models are translated.
   * @returns The generated lines joined with newline characters.
   */
  public generate(ast: Program): string {
    this.output = [];
    this.indentLevel = 0;

    // Generate header comment
    this.emit('% Generated ProbLog code from PDSL');
    this.emit('% ===================================');
    this.emit('');

    for (const model of ast.models) {
      this.generateModel(model);
    }

    return this.output.join('\n');
  }

  // ==========================================================================
  // Model Generation
  // ==========================================================================

  /**
   * Emit one model: a `% Model:` comment, then its statements grouped into
   * sections in a fixed order (facts, annotated disjunctions, rules, evidence,
   * learning, queries). Empty sections are omitted. Statements with a `type`
   * not listed in the switch are skipped.
   */
  private generateModel(model: Model): void {
    this.emit(`% Model: ${this.toCommentText(model.name)}`);
    this.emit('');

    // Group statements by type for better organization
    const facts: Statement[] = [];
    const rules: Statement[] = [];
    const ads: Statement[] = [];
    const observations: Statement[] = [];
    const queries: Statement[] = [];
    const learning: Statement[] = [];

    for (const stmt of model.statements) {
      switch (stmt.type) {
        case 'ProbabilisticFact':
        case 'DeterministicFact':
          facts.push(stmt);
          break;
        case 'ProbabilisticRule':
          rules.push(stmt);
          break;
        case 'AnnotatedDisjunction':
          ads.push(stmt);
          break;
        case 'Observation':
          observations.push(stmt);
          break;
        case 'Query':
          queries.push(stmt);
          break;
        case 'LearningDirective':
          learning.push(stmt);
          break;
      }
    }

    // Output order of the sections is significant and fixed here.
    const sections: ReadonlyArray<readonly [string, Statement[]]> = [
      ['% Facts', facts],
      ['% Annotated Disjunctions', ads],
      ['% Rules', rules],
      ['% Evidence', observations],
      ['% Learning', learning],
      ['% Queries', queries],
    ];
    for (const [header, statements] of sections) {
      this.emitSection(header, statements);
    }
  }

  /**
   * Emit `header`, every statement, and a trailing blank line.
   * Does nothing when `statements` is empty.
   */
  private emitSection(header: string, statements: Statement[]): void {
    if (statements.length === 0) {
      return;
    }
    this.emit(header);
    for (const stmt of statements) {
      this.generateStatement(stmt);
    }
    this.emit('');
  }

  // ==========================================================================
  // Statement Generation
  // ==========================================================================

  /** Dispatch a statement to its type-specific generator. */
  private generateStatement(stmt: Statement): void {
    switch (stmt.type) {
      case 'ProbabilisticFact':
        this.generateProbabilisticFact(stmt);
        break;
      case 'ProbabilisticRule':
        this.generateProbabilisticRule(stmt);
        break;
      case 'DeterministicFact':
        this.generateDeterministicFact(stmt);
        break;
      case 'AnnotatedDisjunction':
        this.generateAnnotatedDisjunction(stmt);
        break;
      case 'Observation':
        this.generateObservation(stmt);
        break;
      case 'Query':
        this.generateQuery(stmt);
        break;
      case 'LearningDirective':
        this.generateLearningDirective(stmt);
        break;
    }
  }

  /** `p::atom.` */
  private generateProbabilisticFact(fact: ProbabilisticFact): void {
    const prob = this.formatProbability(fact.probability);
    const atom = this.generateAtom(fact.atom);
    this.emit(`${prob}::${atom}.`);
  }

  /**
   * `p::head :- body.` — or `p::head.` when the body is empty, since
   * `head :- .` is not valid ProbLog and a condition-free rule is a fact.
   */
  private generateProbabilisticRule(rule: ProbabilisticRule): void {
    const prob = this.formatProbability(rule.probability);
    const head = this.generateAtom(rule.head);
    if (rule.body.length === 0) {
      this.emit(`${prob}::${head}.`);
      return;
    }
    const body = this.generateBody(rule.body);
    this.emit(`${prob}::${head} :- ${body}.`);
  }

  /** `atom.` */
  private generateDeterministicFact(fact: DeterministicFact): void {
    const atom = this.generateAtom(fact.atom);
    this.emit(`${atom}.`);
  }

  /** `p1::a1; p2::a2; ... .` */
  private generateAnnotatedDisjunction(ad: AnnotatedDisjunction): void {
    const choices = ad.choices
      .map(choice => {
        const prob = this.formatProbability(choice.probability);
        const atom = this.generateAtom(choice.atom);
        return `${prob}::${atom}`;
      })
      .join('; ');
    this.emit(`${choices}.`);
  }

  /** `evidence(atom, true).`, or `false` when the observed literal is negated. */
  private generateObservation(obs: Observation): void {
    const atom = this.generateAtom(obs.literal.atom);
    const value = obs.literal.negated ? 'false' : 'true';
    this.emit(`evidence(${atom}, ${value}).`);
  }

  /** `query(atom).` */
  private generateQuery(query: Query): void {
    const atom = this.generateAtom(query.atom);
    this.emit(`query(${atom}).`);
  }

  /**
   * Emit a comment naming the dataset plus a `:- learn('<dataset>').`
   * directive. The dataset is escaped so quotes/backslashes/newlines cannot
   * terminate the quoted atom, and line breaks are flattened in the comment.
   */
  private generateLearningDirective(learning: LearningDirective): void {
    this.emit(`% Learn parameters from ${this.toCommentText(learning.dataset)}`);
    this.emit(`:- learn('${this.escapeQuotedAtom(learning.dataset)}').`);
  }

  // ==========================================================================
  // Logical Constructs
  // ==========================================================================

  /** Comma-separated (conjunctive) list of literals. */
  private generateBody(body: Literal[]): string {
    return body.map(lit => this.generateLiteral(lit)).join(', ');
  }

  /** An atom, prefixed with `\+ ` when negated. */
  private generateLiteral(literal: Literal): string {
    const atom = this.generateAtom(literal.atom);
    return literal.negated ? `\\+ ${atom}` : atom;
  }

  /** `predicate` when there are no args, else `predicate(arg1, arg2, ...)`. */
  private generateAtom(atom: Atom): string {
    if (atom.args.length === 0) {
      return atom.predicate;
    }
    const args = atom.args.map(arg => this.generateTerm(arg)).join(', ');
    return `${atom.predicate}(${args})`;
  }

  /** Render a term; nested atoms recurse through `generateAtom`. */
  private generateTerm(term: Term): string {
    switch (term.type) {
      case 'Variable':
        return term.name;
      case 'Constant':
        return term.value;
      case 'Number':
        return term.value.toString();
      case 'Atom':
        return this.generateAtom(term);
      default: {
        // Exhaustiveness guard: if Term gains a variant, this fails to compile
        // instead of silently emitting `undefined` into the program.
        const unhandled: never = term;
        throw new Error(`Unsupported term: ${JSON.stringify(unhandled)}`);
      }
    }
  }

  // ==========================================================================
  // Formatting Helpers
  // ==========================================================================

  /**
   * Format a probability annotation.
   *
   * Strings (e.g. learnable parameter names) are emitted verbatim. Numbers are
   * rounded to at most 6 decimal places with trailing zeros removed.
   * NOTE(review): because of `toFixed(6)`, nonzero values below 5e-7 round to
   * `0` — confirm that is acceptable for the models being generated.
   */
  private formatProbability(prob: number | string): string {
    if (typeof prob === 'string') {
      // Variable name (for learning)
      return prob;
    }
    if (prob === 0 || prob === 1) {
      return prob.toString();
    }
    // Use up to 6 decimal places, removing trailing zeros
    return parseFloat(prob.toFixed(6)).toString();
  }

  /** Escape text for use inside a single-quoted Prolog atom ('...'). */
  private escapeQuotedAtom(text: string): string {
    return text
      .replace(/[\\']/g, ch => `\\${ch}`)
      .replace(/\n/g, '\\n')
      .replace(/\r/g, '\\r');
  }

  /** Collapse line breaks so text cannot break out of a one-line `%` comment. */
  private toCommentText(text: string): string {
    return text.replace(/[\r\n]+/g, ' ');
  }

  /** Append one line, prefixed by one space per indentation level. */
  private emit(line: string): void {
    const indent = ' '.repeat(this.indentLevel);
    this.output.push(indent + line);
  }

  // indent/dedent are not currently called within this class.
  private indent(): void {
    this.indentLevel++;
  }

  private dedent(): void {
    this.indentLevel = Math.max(0, this.indentLevel - 1);
  }
}
// ============================================================================
// Utility Functions
// ============================================================================
/**
 * Convenience entry point: translate a PDSL program AST into ProbLog source.
 *
 * Each call constructs its own {@link ProbLogGenerator}, so no generator
 * state is shared between invocations.
 *
 * @param ast - The (validated) PDSL program AST.
 * @returns The generated ProbLog code.
 */
export function generate(ast: Program): string {
  return new ProbLogGenerator().generate(ast);
}
/**
 * Options accepted by {@link generateWithOptions}.
 *
 * Every field is optional; an empty object requests the default output.
 * NOTE(review): the per-field descriptions below state the intent suggested
 * by each name — check `generateWithOptions` for which ones are honored.
 */
export interface GeneratorOptions {
  /** Intended: group statements into per-type sections. */
  groupByType?: boolean;
  /** Intended: emit `%` comment lines in the output. */
  includeComments?: boolean;
  /** Intended: number of spaces per indentation level. */
  indentSpaces?: number;
}
/**
 * Generate ProbLog code from a PDSL AST, honoring the supported subset of
 * {@link GeneratorOptions}.
 *
 * - `includeComments: false` removes every line whose first non-blank
 *   character is `%` (ProbLog line comments). All other lines, including the
 *   blank separator lines, are kept. When omitted or `true`, the output is
 *   identical to {@link generate}.
 * - `groupByType` and `indentSpaces` are accepted but not applied yet; the
 *   output always uses the default grouped layout.
 *
 * @param ast - Program AST to translate.
 * @param options - Generation options; defaults to `{}`.
 * @returns ProbLog source text.
 */
export function generateWithOptions(
  ast: Program,
  options: GeneratorOptions = {}
): string {
  const code = generate(ast);
  // Only an explicit `false` opts out; undefined keeps comments (default).
  if (options.includeComments !== false) {
    return code;
  }
  return code
    .split('\n')
    .filter(line => !line.trimStart().startsWith('%'))
    .join('\n');
}