Add wiring

This commit is contained in:
jmsgrogan 2023-01-26 14:26:33 +00:00
parent 20c13c1cdf
commit a74dfd5f5f
20 changed files with 553 additions and 14 deletions

View file

@ -14,6 +14,7 @@ list(APPEND HEADERS
BasicQuantumGates.h
visuals/BlochSphereNode.h
visuals/QuantumCircuitNode.h
visuals/QuantumCircuitElementNode.h
visuals/QuantumGateNode.h
visuals/QuantumWireNode.h
visuals/QuantumTerminalNode.h
@ -33,6 +34,7 @@ list(APPEND SOURCES
BasicQuantumGates.cpp
visuals/BlochSphereNode.cpp
visuals/QuantumCircuitNode.cpp
visuals/QuantumCircuitElementNode.cpp
visuals/QuantumGateNode.cpp
visuals/QuantumWireNode.cpp
visuals/QuantumTerminalNode.cpp

View file

@ -2,6 +2,8 @@
#include "QuantumGate.h"
#include "FileLogger.h"
void QuantumCircuit::addInputTerminal(QuantumTerminalPtr terminal)
{
mInputTerminals.push_back(terminal.get());
@ -44,4 +46,89 @@ const std::vector<QuantumGate*>& QuantumCircuit::getLogicGates() const
const std::vector<QuantumWire*>& QuantumCircuit::getQuantumWires() const
{
return mWires;
}
void QuantumCircuit::buildWireConnections()
{
for (const auto& wire : mWires)
{
bool input_set{ false };
bool output_set{ false };
for (auto input : mInputTerminals)
{
if (wire->getInput() == input)
{
input->setConnection(wire);
input_set = true;
break;
}
}
for (auto output : mOutputTerminals)
{
if (wire->getOutput() == output)
{
output->setConnection(wire);
output_set = true;
break;
}
}
if (input_set && output_set)
{
break;
}
for (auto gate : mGates)
{
if (wire->getOutput() == gate)
{
for (std::size_t idx = 0; idx < gate->getNumInputs(); idx++)
{
if (gate->getInput(idx) == nullptr)
{
gate->setAtInput(idx, wire);
output_set = true;
break;
}
}
}
if (wire->getInput() == gate)
{
for (std::size_t idx = 0; idx < gate->getNumOutputs(); idx++)
{
if (gate->getOutput(idx) == nullptr)
{
gate->setAtOutput(idx, wire);
input_set = true;
break;
}
}
}
if (input_set && output_set)
{
break;
}
}
}
if (!connectivityIsValid())
{
MLOG_ERROR("Input circuit does not have complete connectivity");
}
}
bool QuantumCircuit::connectivityIsValid() const
{
for (const auto& element : mElements)
{
if (!element->isFullyConnected())
{
return false;
}
}
return true;
}

View file

@ -18,6 +18,8 @@ public:
void addLogicGate(QuantumGatePtr gate);
void buildWireConnections();
const std::vector<QuantumTerminal*>& getInputTerminals() const;
const std::vector<QuantumTerminal*>& getOutputTerminals() const;
@ -27,6 +29,8 @@ public:
const std::vector<QuantumWire*>& getQuantumWires() const;
private:
bool connectivityIsValid() const;
std::vector<QuantumTerminal*> mInputTerminals;
std::vector<QuantumTerminal*> mOutputTerminals;

View file

@ -16,5 +16,7 @@ public:
virtual ~QuantumCircuitElement() = default;
virtual bool isFullyConnected() const = 0;
virtual QuantumCircuitElement::Type getType() const = 0;
};

View file

@ -27,6 +27,7 @@ std::unique_ptr<QuantumCircuit> QuantumCircuitReader::read(const std::string& co
onLine(line, cursor);
cursor++;
}
circuit->buildWireConnections();
return circuit;
}

View file

@ -3,6 +3,7 @@
#include "QuantumCircuitElement.h"
#include "QuantumWire.h"
#include <vector>
class QuantumGate : public QuantumCircuitElement
@ -27,6 +28,10 @@ public:
virtual AbstractQuantumWire* getOutput(std::size_t idx) const = 0;
virtual void setAtInput(std::size_t idx, AbstractQuantumWire* value) = 0;
virtual void setAtOutput(std::size_t idx, AbstractQuantumWire* value) = 0;
virtual GateType getGateType() const = 0;
Type getType() const override
@ -52,9 +57,29 @@ public:
AbstractQuantumWire* getOutput(std::size_t idx) const override;
void setAtInput(std::size_t idx, AbstractQuantumWire* value);
void setAtInput(std::size_t idx, AbstractQuantumWire* value) override;
void setAtOutput(std::size_t idx, AbstractQuantumWire* value);
void setAtOutput(std::size_t idx, AbstractQuantumWire* value) override;
bool isFullyConnected() const override
{
for (const auto input : mInputs)
{
if (input == nullptr)
{
return false;
}
}
for (const auto output : mOutputs)
{
if (output == nullptr)
{
return false;
}
}
return true;
}
private:
std::size_t mNumIn{ 1 };

View file

@ -23,6 +23,16 @@ public:
Type getType() const override;
TerminalType getTerminalType() const
{
return mType;
}
bool isFullyConnected() const override
{
return mConnection;
}
const Qubit& getValue() const;
void setConnection(QuantumWire* connection);

View file

@ -1,12 +1,18 @@
#include "QuantumWire.h"
QuantumWire::QuantumWire(QuantumCircuitElement* input, QuantumCircuitElement* output)
AbstractQuantumWire::AbstractQuantumWire(QuantumCircuitElement* input, QuantumCircuitElement* output)
: mInput(input),
mOutput(output)
{
}
QuantumWire::QuantumWire(QuantumCircuitElement* input, QuantumCircuitElement* output)
: AbstractQuantumWire(input, output)
{
}
QuantumCircuitElement* QuantumWire::getInput() const
{
return mInput;

View file

@ -11,9 +11,20 @@ public:
CLASSICAL
};
AbstractQuantumWire(QuantumCircuitElement* input, QuantumCircuitElement* output);
virtual ~AbstractQuantumWire() = default;
virtual WireType getWireType() const = 0;
bool isFullyConnected() const override
{
return mInput && mOutput;
}
protected:
QuantumCircuitElement* mInput{ nullptr };
QuantumCircuitElement* mOutput{ nullptr };
};
class QuantumWire : public AbstractQuantumWire
@ -26,9 +37,6 @@ public:
Type getType() const override;
WireType getWireType() const override;
private:
QuantumCircuitElement* mInput{ nullptr };
QuantumCircuitElement* mOutput{ nullptr };
};
using QuantumWirePtr = std::unique_ptr<QuantumWire>;

View file

@ -0,0 +1,12 @@
#include "QuantumCircuitElementNode.h"
QuantumCircuitElementNode::QuantumCircuitElementNode(const Transform& t)
: AbstractVisualNode(t)
{
}
QuantumCircuitElementNode::~QuantumCircuitElementNode()
{
}

View file

@ -0,0 +1,15 @@
#pragma once
#include "AbstractVisualNode.h"
#include "QuantumWire.h"
class QuantumCircuitElementNode : public AbstractVisualNode
{
public:
QuantumCircuitElementNode(const Transform& t = {});
virtual ~QuantumCircuitElementNode();
virtual Point getConnectionLocation(AbstractQuantumWire* wire) const = 0;
};

View file

@ -1,9 +1,16 @@
#include "QuantumCircuitNode.h"
#include "QuantumCircuit.h"
#include "QuantumCircuitElement.h"
#include "QuantumWire.h"
#include "QuantumCircuitElementNode.h"
#include "QuantumTerminalNode.h"
#include "QuantumWireNode.h"
#include "QuantumGateNode.h"
QuantumCircuitNode::QuantumCircuitNode(const Transform& t)
: AbstractVisualNode(t)
{
}
@ -23,9 +30,40 @@ void QuantumCircuitNode::update(SceneInfo* sceneInfo)
}
}
void QuantumCircuitNode::buildWireConnections()
{
mWireInputConnections.clear();
mWireOutputConnections.clear();
for (auto terminal : mContent->getInputTerminals())
{
mWireOutputConnections[terminal->getConnection()] = terminal;
}
for (auto gate : mContent->getLogicGates())
{
for (std::size_t idx = 0; idx < gate->getNumInputs(); idx++)
{
mWireInputConnections[gate->getInput(idx)] = gate;
}
for (std::size_t idx = 0; idx < gate->getNumOutputs(); idx++)
{
mWireOutputConnections[gate->getOutput(idx)] = gate;
}
}
for (auto terminal : mContent->getOutputTerminals())
{
mWireInputConnections[terminal->getConnection()] = terminal;
}
}
void QuantumCircuitNode::createOrUpdateGeometry(SceneInfo*)
{
double terminal_vertical_spacing = 100;
buildWireConnections();
double wire_vertical_spacing = 50;
double terminal_left_margin = 10;
double terminal_y = 10;
@ -37,8 +75,56 @@ void QuantumCircuitNode::createOrUpdateGeometry(SceneInfo*)
terminal_node->setContent(terminal);
addChild(terminal_node.get());
mNodesForContent[terminal] = terminal_node.get();
mInputTerminalNodes.push_back(std::move(terminal_node));
terminal_y += terminal_vertical_spacing;
terminal_y += wire_vertical_spacing;
}
terminal_y = 10;
for (auto terminal : mContent->getOutputTerminals())
{
Point loc{ 150.0, terminal_y };
auto terminal_node = std::make_unique<QuantumTerminalNode>(Transform(loc));
terminal_node->setContent(terminal);
addChild(terminal_node.get());
mNodesForContent[terminal] = terminal_node.get();
mOutputTerminalNodes.push_back(std::move(terminal_node));
terminal_y += wire_vertical_spacing;
}
double gate_y = 0;
for (auto gate : mContent->getLogicGates())
{
Point loc{ 75.0, gate_y };
auto gate_node = std::make_unique<QuantumGateNode>(Transform(loc));
gate_node->setContent(gate);
addChild(gate_node.get());
mNodesForContent[gate] = gate_node.get();
mGateNodes.push_back(std::move(gate_node));
gate_y += wire_vertical_spacing;
}
for (auto wire : mContent->getQuantumWires())
{
auto start_node = mNodesForContent[mWireOutputConnections[wire]];
auto end_node = mNodesForContent[mWireInputConnections[wire]];
auto wire_node = std::make_unique<QuantumWireNode>(Transform());
auto start_loc = start_node->getConnectionLocation(wire);
wire_node->setInputLocation(start_loc);
auto end_loc = end_node->getConnectionLocation(wire);
auto straight_end_loc = Point(end_loc.getX(), start_loc.getY());
wire_node->setOutputLocation(straight_end_loc);
addChild(wire_node.get());
mWireNodes.push_back(std::move(wire_node));
}
}

View file

@ -2,8 +2,16 @@
#include "AbstractVisualNode.h"
#include <unordered_map>
class QuantumCircuit;
class QuantumCircuitElement;
class AbstractQuantumWire;
class QuantumTerminalNode;
class QuantumGateNode;
class QuantumWireNode;
class QuantumCircuitElementNode;
class QuantumCircuitNode : public AbstractVisualNode
{
@ -16,8 +24,18 @@ public:
private:
void createOrUpdateGeometry(SceneInfo* sceneInfo);
void buildWireConnections();
bool mContentDirty{ true };
QuantumCircuit* mContent{ nullptr };
std::vector<std::unique_ptr<QuantumTerminalNode> > mInputTerminalNodes;
std::vector<std::unique_ptr<QuantumTerminalNode> > mOutputTerminalNodes;
std::vector<std::unique_ptr<QuantumGateNode> > mGateNodes;
std::vector<std::unique_ptr<QuantumWireNode> > mWireNodes;
std::unordered_map<AbstractQuantumWire*, QuantumCircuitElement*> mWireInputConnections;
std::unordered_map<AbstractQuantumWire*, QuantumCircuitElement*> mWireOutputConnections;
std::unordered_map<QuantumCircuitElement*, QuantumCircuitElementNode*> mNodesForContent;
};

View file

@ -0,0 +1,111 @@
#include "QuantumGateNode.h"
#include "RectangleNode.h"
#include "EquationNode.h"
#include "QuantumGate.h"
#include "LatexMathExpression.h"
QuantumGateNode::QuantumGateNode(const Transform& t)
: QuantumCircuitElementNode(t)
{
}
QuantumGateNode::~QuantumGateNode()
{
}
void QuantumGateNode::setContent(QuantumGate* gate)
{
mContent = gate;
mContentDirty = true;
}
void QuantumGateNode::update(SceneInfo* sceneInfo)
{
if (mContentDirty)
{
createOrUpdateGeometry(sceneInfo);
mContentDirty = false;
}
}
void QuantumGateNode::createOrUpdateGeometry(SceneInfo* sceneInfo)
{
if (!mBody)
{
mBody = std::make_unique<RectangleNode>(Point(0, 0), mBodyWidth, mBodyHeight);
addChild(mBody.get());
}
if (!mLabel)
{
mLabel = std::make_unique<EquationNode>(Point(mBodyWidth /3.0, mBodyHeight / 3.0));
std::string label_content;
if (mContent->getGateType() == QuantumGate::GateType::X)
{
label_content = "X";
}
else if (mContent->getGateType() == QuantumGate::GateType::Y)
{
label_content = "Y";
}
else if (mContent->getGateType() == QuantumGate::GateType::Z)
{
label_content = "Z";
}
else if (mContent->getGateType() == QuantumGate::GateType::H)
{
label_content = "H";
}
else
{
label_content = "U";
}
mLabelExpression = std::make_unique<LatexMathExpression>(label_content);
mLabel->setContent(mLabelExpression.get());
addChild(mLabel.get());
}
}
Point QuantumGateNode::getConnectionLocation(AbstractQuantumWire* wire) const
{
bool is_input{ false };
std::size_t connection_id{ 0 };
for (std::size_t idx = 0; idx < mContent->getNumInputs(); idx++)
{
if (mContent->getInput(idx) == wire)
{
is_input = true;
connection_id = idx;
break;
}
}
for (std::size_t idx = 0; idx < mContent->getNumOutputs(); idx++)
{
if (mContent->getOutput(idx) == wire)
{
connection_id = idx;
break;
}
}
Point loc;
if (is_input)
{
loc = Point(0.0, mBodyHeight/2.0);
}
else
{
loc = Point(mBodyWidth, mBodyHeight / 2.0);
}
loc.move(mTransform.getLocation().getX(), mTransform.getLocation().getY());
return loc;
}

View file

@ -0,0 +1,34 @@
#pragma once
#include "QuantumCircuitElementNode.h"
class RectangleNode;
class EquationNode;
class QuantumGate;
class LatexMathExpression;
class QuantumGateNode : public QuantumCircuitElementNode
{
public:
QuantumGateNode(const Transform& t = {});
virtual ~QuantumGateNode();
Point getConnectionLocation(AbstractQuantumWire* wire) const override;
void setContent(QuantumGate* gate);
void update(SceneInfo* sceneInfo);
private:
void createOrUpdateGeometry(SceneInfo* sceneInfo);
QuantumGate* mContent{ nullptr };
bool mContentDirty{ true };
std::unique_ptr<RectangleNode> mBody;
double mBodyWidth = 30;
double mBodyHeight = 24;
std::unique_ptr<LatexMathExpression> mLabelExpression;
std::unique_ptr<EquationNode> mLabel;
};

View file

@ -5,7 +5,7 @@
#include "LatexMathExpression.h"
QuantumTerminalNode::QuantumTerminalNode(const Transform& transform)
: AbstractVisualNode(transform)
: QuantumCircuitElementNode(transform)
{
}
@ -27,7 +27,7 @@ void QuantumTerminalNode::update(SceneInfo* sceneInfo)
void QuantumTerminalNode::createOrUpdateGeometry(SceneInfo* sceneInfo)
{
if (!mLabel)
if (!mLabel && mContent->getTerminalType() != QuantumTerminal::TerminalType::OUTPUT)
{
const auto value = mContent->getValue();
std::string label;
@ -50,4 +50,11 @@ void QuantumTerminalNode::createOrUpdateGeometry(SceneInfo* sceneInfo)
addChild(mLabel.get());
}
}
}
Point QuantumTerminalNode::getConnectionLocation(AbstractQuantumWire*) const
{
auto left = mTransform.getLocation();
left.move(mWidth, mHeight/2.0);
return left;
}

View file

@ -1,16 +1,18 @@
#pragma once
#include "AbstractVisualNode.h"
#include "QuantumCircuitElementNode.h"
class QuantumTerminal;
class EquationNode;
class LatexMathExpression;
class QuantumTerminalNode : public AbstractVisualNode
class QuantumTerminalNode : public QuantumCircuitElementNode
{
public:
QuantumTerminalNode(const Transform& transform);
Point getConnectionLocation(AbstractQuantumWire* wire) const override;
void setContent(QuantumTerminal* terminal);
void update(SceneInfo* sceneInfo);
@ -20,6 +22,9 @@ private:
QuantumTerminal* mContent{ nullptr };
bool mContentDirty{ true };
double mWidth = 20.0;
double mHeight = 10.0;
std::unique_ptr<LatexMathExpression> mLabelExpression;
std::unique_ptr<EquationNode> mLabel;
};

View file

@ -0,0 +1,73 @@
#include "QuantumWireNode.h"
#include "QuantumWire.h"
#include "LineNode.h"
QuantumWireNode::QuantumWireNode(const Transform& t)
{
}
QuantumWireNode::~QuantumWireNode()
{
}
void QuantumWireNode::setContent(QuantumWire* content)
{
mContent = content;
mContentDirty = true;
}
void QuantumWireNode::setInputLocation(const Point& point)
{
if (mInputLocation != point)
{
mContentDirty = true;
mInputLocation = point;
}
}
void QuantumWireNode::setOutputLocation(const Point& point)
{
if (mOutputLocation != point)
{
mContentDirty = true;
mOutputLocation = point;
}
}
void QuantumWireNode::update(SceneInfo* sceneInfo)
{
if (mContentDirty)
{
createOrUpdateGeometry(sceneInfo);
mContentDirty = false;
}
}
void QuantumWireNode::createOrUpdateGeometry(SceneInfo* sceneInfo)
{
if (!mLine)
{
auto loc = mOutputLocation;
loc.move(-mInputLocation.getX(), -mInputLocation.getY(), -mInputLocation.getZ());
std::vector<Point> points;
if (loc.getY() == 0.0)
{
points = { loc };
}
else
{
auto join0 = Point(loc.getX() * 3.0 / 4.0, 0.0);
auto join1 = Point(loc.getX() * 3.0 / 4.0, loc.getY());
points = { join0, join1 , loc };
}
mLine = std::make_unique<LineNode>(Transform(mInputLocation), points);
addChild(mLine.get());
}
}

View file

@ -0,0 +1,33 @@
#pragma once
#include "AbstractVisualNode.h"
class QuantumWire;
class LineNode;
class QuantumWireNode : public AbstractVisualNode
{
public:
QuantumWireNode(const Transform& t = {});
virtual ~QuantumWireNode();
void setInputLocation(const Point& point);
void setOutputLocation(const Point& point);
void setContent(QuantumWire* content);
void update(SceneInfo* sceneInfo);
private:
void createOrUpdateGeometry(SceneInfo* sceneInfo);
QuantumWire* mContent{ nullptr };
bool mContentDirty{ true };
Point mInputLocation;
Point mOutputLocation;
std::unique_ptr<LineNode> mLine;
};

View file

@ -15,7 +15,7 @@ TEST_CASE(TestQuantumCircuitParsing, "quantum_computing")
TestRenderer renderer(100, 100);
auto node = std::make_unique<QuantumCircuitNode>(Point(10, 10));
auto node = std::make_unique<QuantumCircuitNode>(Point(20, 20));
node->setContent(circuit.get());