Module odinson.ruleutils.oracle
Expand source code
from collections import defaultdict
from typing import Dict, Optional, List
from odinson.ruleutils.queryast import *
from odinson.ruleutils.config import Vocabularies, ENTITY_FIELD, SYNTAX_FIELD
def make_transition_table(paths):
"""Gets a list of paths and returns a transition table."""
trans = defaultdict(set)
for path in paths:
for i in range(len(path) - 1):
trans[path[i]].add(path[i + 1])
return {k: list(v) for k, v in trans.items()}
def all_paths_from_root(
target: AstNode, vocabularies: Optional[Vocabularies] = None
) -> List[List[AstNode]]:
"""Returns all episodes that render an equivalent rule to target."""
results = []
for p in target.permutations():
results.append(path_from_root(p, vocabularies))
return results
def path_from_root(
target: AstNode, vocabularies: Optional[Vocabularies] = None
) -> List[AstNode]:
"""
Returns the sequence of transitions from the root of the search tree
to the specified AstNode.
"""
if isinstance(target, Query):
root = HoleQuery()
elif isinstance(target, Traversal):
root = HoleTraversal()
elif isinstance(target, Surface):
root = HoleSurface()
elif isinstance(target, Constraint):
root = HoleConstraint()
else:
raise ValueError(f"unsupported target type '{type(target)}'")
if vocabularies is None:
# If no vocabularies were provided then construct
# the minimal vocabularies required to reach the target.
vocabularies = make_minimal_vocabularies(target)
oracle = Oracle(root, target, vocabularies)
return list(oracle.traversal())
class Oracle:
def __init__(self, src: AstNode, dst: AstNode, vocabularies: Vocabularies):
self.src = src
self.dst = dst
self.vocabularies = vocabularies
# find traversal corresponding to dst node
self.dst_traversal = self.dst.preorder_traversal()
def traversal(self):
current = self.src
while current is not None:
yield current
if current == self.dst:
break
current = self.next_step(current)
def next_step(self, current: AstNode):
"""Returns the next step in the path from src to dst."""
# find position of first hole in current node's traversal
hole_position = -1
for i, n in enumerate(current.preorder_traversal()):
if n.is_hole():
hole_position = i
break
# if there is no hole then there is no next step
if hole_position < 0:
return
# consider all possible candidates
for candidate in current.expand_leftmost_hole(self.vocabularies):
traversal = candidate.preorder_traversal()
n1 = traversal[hole_position]
n2 = self.dst_traversal[hole_position]
if are_compatible(n1, n2):
# if candidate has a node in the right position of its traversal
# that is compatible with the node at the same position in the dst traversal
# then we have a winner
return candidate
def are_compatible(x: AstNode, y: AstNode) -> bool:
"""
Compares two nodes to see if they're compatible.
Note that this does not compare for equality,
because the nodes may contain holes.
"""
if isinstance(x, ExactMatcher) and isinstance(y, ExactMatcher):
return x.string == y.string
elif isinstance(x, RepeatSurface) and isinstance(y, RepeatSurface):
return x.min == y.min and x.max == y.max
elif isinstance(x, HoleSurface) and isinstance(y, Surface):
return True
else:
return type(x) == type(y)
def make_minimal_vocabularies(node: AstNode) -> Vocabularies:
"""Returns the collection of vocabularies required to build the given rule."""
vocabularies = defaultdict(set)
for n in node.preorder_traversal():
if isinstance(n, FieldConstraint):
name = n.name.string
value = n.value.string
vocabularies[name].add(value)
if isinstance(n, MentionSurface):
label = n.label.string
vocabularies[ENTITY_FIELD].add(label)
elif isinstance(n, (IncomingLabelTraversal, OutgoingLabelTraversal)):
label = n.label.string
vocabularies[SYNTAX_FIELD].add(label)
return {k: list(v) for k, v in vocabularies.items()}
Functions
def all_paths_from_root(target: AstNode, vocabularies: Optional[Dict[str, List[str]]] = None) ‑> List[List[AstNode]]
-
Returns all episodes that render an equivalent rule to target.
Expand source code
def all_paths_from_root( target: AstNode, vocabularies: Optional[Vocabularies] = None ) -> List[List[AstNode]]: """Returns all episodes that render an equivalent rule to target.""" results = [] for p in target.permutations(): results.append(path_from_root(p, vocabularies)) return results
def are_compatible(x: AstNode, y: AstNode) ‑> bool
-
Compares two nodes to see if they're compatible. Note that this does not compare for equality, because the nodes may contain holes.
Expand source code
def are_compatible(x: AstNode, y: AstNode) -> bool: """ Compares two nodes to see if they're compatible. Note that this does not compare for equality, because the nodes may contain holes. """ if isinstance(x, ExactMatcher) and isinstance(y, ExactMatcher): return x.string == y.string elif isinstance(x, RepeatSurface) and isinstance(y, RepeatSurface): return x.min == y.min and x.max == y.max elif isinstance(x, HoleSurface) and isinstance(y, Surface): return True else: return type(x) == type(y)
def make_minimal_vocabularies(node: AstNode) ‑> Dict[str, List[str]]
-
Returns the collection of vocabularies required to build the given rule.
Expand source code
def make_minimal_vocabularies(node: AstNode) -> Vocabularies: """Returns the collection of vocabularies required to build the given rule.""" vocabularies = defaultdict(set) for n in node.preorder_traversal(): if isinstance(n, FieldConstraint): name = n.name.string value = n.value.string vocabularies[name].add(value) if isinstance(n, MentionSurface): label = n.label.string vocabularies[ENTITY_FIELD].add(label) elif isinstance(n, (IncomingLabelTraversal, OutgoingLabelTraversal)): label = n.label.string vocabularies[SYNTAX_FIELD].add(label) return {k: list(v) for k, v in vocabularies.items()}
def make_transition_table(paths)
-
Gets a list of paths and returns a transition table.
Expand source code
def make_transition_table(paths): """Gets a list of paths and returns a transition table.""" trans = defaultdict(set) for path in paths: for i in range(len(path) - 1): trans[path[i]].add(path[i + 1]) return {k: list(v) for k, v in trans.items()}
def path_from_root(target: AstNode, vocabularies: Optional[Dict[str, List[str]]] = None) ‑> List[AstNode]
-
Returns the sequence of transitions from the root of the search tree to the specified AstNode.
Expand source code
def path_from_root( target: AstNode, vocabularies: Optional[Vocabularies] = None ) -> List[AstNode]: """ Returns the sequence of transitions from the root of the search tree to the specified AstNode. """ if isinstance(target, Query): root = HoleQuery() elif isinstance(target, Traversal): root = HoleTraversal() elif isinstance(target, Surface): root = HoleSurface() elif isinstance(target, Constraint): root = HoleConstraint() else: raise ValueError(f"unsupported target type '{type(target)}'") if vocabularies is None: # If no vocabularies were provided then construct # the minimal vocabularies required to reach the target. vocabularies = make_minimal_vocabularies(target) oracle = Oracle(root, target, vocabularies) return list(oracle.traversal())
Classes
class Oracle (src: AstNode, dst: AstNode, vocabularies: Dict[str, List[str]])
-
Expand source code
class Oracle: def __init__(self, src: AstNode, dst: AstNode, vocabularies: Vocabularies): self.src = src self.dst = dst self.vocabularies = vocabularies # find traversal corresponding to dst node self.dst_traversal = self.dst.preorder_traversal() def traversal(self): current = self.src while current is not None: yield current if current == self.dst: break current = self.next_step(current) def next_step(self, current: AstNode): """Returns the next step in the path from src to dst.""" # find position of first hole in current node's traversal hole_position = -1 for i, n in enumerate(current.preorder_traversal()): if n.is_hole(): hole_position = i break # if there is no hole then there is no next step if hole_position < 0: return # consider all possible candidates for candidate in current.expand_leftmost_hole(self.vocabularies): traversal = candidate.preorder_traversal() n1 = traversal[hole_position] n2 = self.dst_traversal[hole_position] if are_compatible(n1, n2): # if candidate has a node in the right position of its traversal # that is compatible with the node at the same position in the dst traversal # then we have a winner return candidate
Methods
def next_step(self, current: AstNode)
-
Returns the next step in the path from src to dst.
Expand source code
def next_step(self, current: AstNode): """Returns the next step in the path from src to dst.""" # find position of first hole in current node's traversal hole_position = -1 for i, n in enumerate(current.preorder_traversal()): if n.is_hole(): hole_position = i break # if there is no hole then there is no next step if hole_position < 0: return # consider all possible candidates for candidate in current.expand_leftmost_hole(self.vocabularies): traversal = candidate.preorder_traversal() n1 = traversal[hole_position] n2 = self.dst_traversal[hole_position] if are_compatible(n1, n2): # if candidate has a node in the right position of its traversal # that is compatible with the node at the same position in the dst traversal # then we have a winner return candidate
def traversal(self)
-
Expand source code
def traversal(self): current = self.src while current is not None: yield current if current == self.dst: break current = self.next_step(current)