320 lines
11 KiB
Python
320 lines
11 KiB
Python
import os
|
|
|
|
from cs50 import SQL
|
|
from flask import Flask, flash, jsonify, redirect, render_template, request, session
|
|
from flask_session import Session
|
|
from tempfile import mkdtemp
|
|
from werkzeug.exceptions import default_exceptions, HTTPException, InternalServerError
|
|
from werkzeug.security import check_password_hash, generate_password_hash
|
|
|
|
from helpers import apology, login_required, lookup, usd
|
|
|
|
# Create the Flask application object that the routes below are registered on
app = Flask(__name__)

# Have templates reloaded automatically when they change
app.config.update(TEMPLATES_AUTO_RELOAD=True)
|
|
|
|
# Ensure responses aren't cached
@app.after_request
def after_request(response):
    """Stamp no-cache headers onto an outgoing response.

    Sets the ``Cache-Control``, ``Expires`` and ``Pragma`` headers on
    ``response`` (in that order) and returns the same response object.
    """
    no_cache_headers = (
        ("Cache-Control", "no-cache, no-store, must-revalidate"),
        ("Expires", 0),
        ("Pragma", "no-cache"),
    )
    for header_name, header_value in no_cache_headers:
        response.headers[header_name] = header_value
    return response
|
|
|
|
# Expose the usd helper to templates as the "usd" Jinja filter
app.jinja_env.filters["usd"] = usd

# Store sessions on the filesystem (instead of signed cookies), in a fresh
# temporary directory created at start-up
app.config.update(
    SESSION_FILE_DIR=mkdtemp(),
    SESSION_PERMANENT=False,
    SESSION_TYPE="filesystem",
)
Session(app)

# Database handle (CS50 library, SQLite) shared by every route
db = SQL("sqlite:///finance.db")
|
|
|
|
|
|
@app.route("/")
@login_required
def index():
    """Show portfolio of stocks.

    Reads the logged-in user's cash and every symbol whose summed share
    count is positive, looks up a current price for each symbol, and renders
    ``index.html`` with the holdings, their prices, their values and the
    combined value of all holdings. Cash and per-holding values are
    truncated to whole numbers with ``int()``.
    """
    account = db.execute("SELECT * FROM users WHERE id = :id", id=session["user_id"])
    cash = int(account[0]["cash"])

    stocks = db.execute("SELECT symbol, SUM(shares) as sum_shares FROM transactions WHERE id = :user_id GROUP BY symbol HAVING sum_shares > 0", user_id=session["user_id"])

    # prices[i] and values[i] line up with stocks[i] for the template
    prices = []
    values = []
    for holding in stocks:
        current = lookup(holding["symbol"])
        prices.append(current["price"])
        values.append(int(current["price"] * holding["sum_shares"]))

    return render_template(
        "index.html",
        stocks=stocks,
        prices=prices,
        values=values,
        total_value=sum(values),
        cash=cash,
    )
|
|
|
|
|
|
@app.route("/buy", methods=["GET", "POST"])
@login_required
def buy():
    """Buy shares of stock.

    GET renders ``buy.html``. POST validates the submitted ``symbol`` and
    ``shares`` fields, looks up the current quote, checks that the user has
    enough cash, then debits the account, records the transaction and
    redirects to the portfolio page. Any validation failure returns an
    apology page instead.

    Note that the ``price`` column of ``transactions`` receives the total
    cost of the purchase (share price times share count), not the per-share
    price.
    """
    # User reached route via GET: just show the form
    if request.method != "POST":
        return render_template("buy.html")

    # Ensure symbol and shares were submitted
    symbol = request.form.get("symbol")
    if not symbol:
        return apology("must provide stock symbol", 400)
    raw_shares = request.form.get("shares")
    if not raw_shares:
        return apology("must provide amount of shares", 400)

    # Ensure shares is a positive integer. isdecimal() only admits characters
    # that int() can parse (isdigit() also admits e.g. superscript digits,
    # which make int() raise), and zero is rejected explicitly.
    if not raw_shares.isdecimal():
        return apology("must enter a positive integer", 400)
    shares = int(raw_shares)
    if shares < 1:
        return apology("must enter a positive integer", 400)

    # Get quote; lookup signals an unknown symbol by returning None
    quote = lookup(symbol)
    if quote is None:
        return apology("please enter a valid quote", 400)

    # Ensure cost is not too high
    cost = quote["price"] * shares
    rows = db.execute("SELECT * FROM users WHERE id = :id", id=session["user_id"])
    if cost > rows[0]["cash"]:
        return apology("not enough cash")

    # Deduct the cost from the user's cash
    db.execute("UPDATE users SET cash = cash - :cost WHERE id = :id", cost=cost, id=session["user_id"])

    # Add transaction to history
    db.execute("Insert INTO transactions (id, symbol, shares, price) VALUES(:id, :symbol, :shares, :cost)",
               id=session["user_id"], symbol=quote["symbol"], shares=shares, cost=cost)

    return redirect("/")
|
|
|
|
@app.route("/trending", methods=["GET", "POST"])
@login_required
def trending():
    """Show every traded symbol, most frequently traded first.

    Groups all rows of ``transactions`` by symbol, ordered by the number of
    rows per symbol (descending), looks up a current price for each symbol
    and renders ``trending.html``.
    """
    stocks = db.execute("SELECT COUNT(id) as count_id, symbol, SUM(shares) as sum_shares FROM transactions GROUP BY symbol ORDER BY count_id DESC")

    # prices[i] lines up with stocks[i] for the template
    prices = [lookup(entry["symbol"])["price"] for entry in stocks]

    return render_template("trending.html", stocks=stocks, prices=prices)
|
|
|
|
@app.route("/check", methods=["GET"])
def check():
    """Report, as JSON, whether the ``username`` query parameter is free.

    Responds with JSON ``true`` only when a non-empty username was supplied
    and no row in ``users`` already carries it; otherwise JSON ``false``.
    """
    candidate = request.args.get("username")
    matches = db.execute("SELECT username FROM users WHERE username = :username", username=candidate)

    available = bool(candidate) and not matches
    return jsonify(available)
|
|
|
|
@app.route("/history")
@login_required
def history():
    """Show history of transactions.

    Builds three lists parallel to the user's transaction rows: a
    "Sold"/"Bought" label, the stored price, and the share count. Rows with
    a negative share count are treated as sales and have both their price
    and share count multiplied by -1 before display.
    """
    stocks = db.execute("SELECT symbol, price, date, shares FROM transactions WHERE id=:user_id", user_id=session["user_id"])

    values = []
    boo = []
    shares = []
    for entry in stocks:
        amount = entry["price"]
        count = entry["shares"]
        if count < 0:
            # A sale: flip the sign of both stored numbers
            label = "Sold"
            amount = amount * -1
            count = count * -1
        else:
            label = "Bought"
        boo.append(label)
        values.append(amount)
        shares.append(count)

    return render_template("history.html", stocks=stocks, boo=boo, values=values, shares=shares)
|
|
|
|
|
|
@app.route("/login", methods=["GET", "POST"])
def login():
    """Log user in.

    Always clears the current session first. GET renders ``login.html``.
    POST checks the submitted username and password against the ``users``
    table; on success the user's id is stored in the session and the
    browser is redirected to "/", otherwise a 403 apology is returned.
    """
    # Forget any user_id
    session.clear()

    # User reached route via GET (as by clicking a link or via redirect)
    if request.method != "POST":
        return render_template("login.html")

    # User reached route via POST (as by submitting a form via POST)
    username = request.form.get("username")
    if not username:
        return apology("must provide username", 403)

    password = request.form.get("password")
    if not password:
        return apology("must provide password", 403)

    # Query database for username
    rows = db.execute("SELECT * FROM users WHERE username = :username", username=username)

    # Exactly one matching row whose stored hash matches the password
    credentials_ok = len(rows) == 1 and check_password_hash(rows[0]["hash"], password)
    if not credentials_ok:
        return apology("invalid username and/or password", 403)

    # Remember which user has logged in, then go to the home page
    session["user_id"] = rows[0]["id"]
    return redirect("/")
|
|
|
|
|
|
@app.route("/logout")
def logout():
    """Log user out by clearing the session, then redirect to "/"."""
    session.clear()  # forget any user_id
    return redirect("/")
|
|
|
|
|
|
@app.route("/quote", methods=["GET", "POST"])
@login_required
def quote():
    """Get stock quote.

    GET renders ``quote.html``. POST looks up the submitted symbol and
    renders ``quoted.html`` with the result, or returns a 400 apology when
    the symbol is missing or the lookup yields nothing.
    """
    # User reached route via GET: show the lookup form
    if request.method != "POST":
        return render_template("quote.html")

    # Ensure a symbol was submitted
    symbol = request.form.get("symbol")
    if not symbol:
        return apology("must provide stock symbol", 400)

    # lookup signals an unknown symbol by returning None
    result = lookup(symbol)
    if result is None:
        return apology("please enter a valid quote", 400)

    return render_template("quoted.html", quote=result)
|
|
|
|
|
|
@app.route("/register", methods=["GET", "POST"])
def register():
    """Register user.

    Always clears the current session first. GET renders ``register.html``.
    POST validates the submitted username, password and confirmation,
    rejects usernames that already exist, stores the new user with a hashed
    password, logs them in and redirects to "/". Validation failures return
    a 400 apology.
    """
    # Forget any user_id
    session.clear()

    # User reached route via GET (as by clicking a link or via redirect)
    if request.method != "POST":
        return render_template("register.html")

    # User reached route via POST (as by submitting a form via POST)
    username = request.form.get("username")
    password = request.form.get("password")

    # Ensure username was submitted
    if not username:
        return apology("must provide username", 400)

    # Ensure password was submitted
    if not password:
        return apology("must provide password", 400)

    # Ensure passwords match
    if password != request.form.get("confirmation"):
        return apology("passwords must match", 400)

    # Check up front whether the username is taken (same query as /check),
    # so detection does not hinge on how a failed INSERT is reported.
    # NOTE(review): whether db.execute returns a falsy value or raises on a
    # constraint violation depends on the cs50 library version, and a UNIQUE
    # constraint on users.username is not visible from this file - confirm.
    existing = db.execute("SELECT username FROM users WHERE username = :username", username=username)
    if existing:
        return apology("Username has been taken", 400)

    # Create hash for password
    password_hash = generate_password_hash(password)

    # Insert the new user; the return value is used as the new user's id
    new_user_id = db.execute("Insert INTO users (username, hash) VALUES(:username, :hash)",
                             username=username, hash=password_hash)

    # Original safety net: a falsy result is treated as a duplicate username
    if not new_user_id:
        return apology("Username has been taken", 400)

    # Remember which user has logged in
    session["user_id"] = new_user_id

    # Redirect user to home page
    return redirect("/")
|
|
|
|
|
|
|
|
@app.route("/sell", methods=["GET", "POST"])
@login_required
def sell():
    """Sell shares of stock.

    GET renders ``sell.html`` with the user's current holdings. POST
    validates the submitted ``symbol`` and ``shares`` fields, checks that the
    user holds at least that many shares of that symbol, then credits the
    proceeds to the user's cash, records the sale and redirects to "/".
    Validation failures return a 400 apology.

    A sale is stored in ``transactions`` with a negative share count and a
    negative ``price`` (the total proceeds, sign-flipped).
    """
    # Current holdings: one row per symbol with a positive share total
    stocks = db.execute("SELECT symbol, SUM(shares) as sum_shares FROM transactions WHERE id = :user_id GROUP BY symbol HAVING sum_shares > 0", user_id=session["user_id"])

    if request.method != "POST":
        return render_template("sell.html", stocks=stocks)

    # Ensure symbol and shares are submitted
    symbol = request.form.get("symbol")
    if not symbol:
        return apology("must choose a symbol", 400)
    raw_shares = request.form.get("shares")
    if not raw_shares:
        return apology("must enter shares amount", 400)

    # Ensure shares is a positive integer; non-numeric input is reported to
    # the user instead of escaping as an unhandled ValueError
    try:
        shares = int(raw_shares)
    except ValueError:
        return apology("must enter an integer", 400)
    if shares < 1:
        return apology("must enter a positive integer", 400)

    # Get quote; lookup signals an unknown symbol by returning None
    quote = lookup(symbol)
    if quote is None:
        return apology("please enter a valid quote", 400)

    # Ensure proper amount of shares: compare against the holding for the
    # symbol being sold (not simply the first holding). Transactions are
    # stored under quote["symbol"], so that is the key to match; a symbol
    # the user does not hold counts as zero shares.
    owned = next((row["sum_shares"] for row in stocks if row["symbol"] == quote["symbol"]), 0)
    if shares > owned:
        return apology("invalid shares amount", 400)

    # Log loss of shares and increase in cash: both are stored negated, so
    # "cash - cost" below adds the proceeds to the balance
    cost = quote["price"] * shares * -1
    shares = shares * -1

    # Update database cash amount
    db.execute("UPDATE users SET cash = cash - :cost WHERE id = :id", cost=cost, id=session["user_id"])

    # Add transaction to history
    db.execute("Insert INTO transactions (id, symbol, shares, price) VALUES(:id, :symbol, :shares, :cost)",
               id=session["user_id"], symbol=quote["symbol"], shares=shares, cost=cost)

    # Redirect user to home page
    return redirect("/")
|
|
|
|
|
|
def errorhandler(e):
    """Render an apology page for an error.

    Anything that is not an ``HTTPException`` is replaced by a fresh
    ``InternalServerError`` before its name and code are passed to
    ``apology``.
    """
    error = e if isinstance(e, HTTPException) else InternalServerError()
    return apology(error.name, error.code)
|
|
|
|
|
|
# Listen for errors: attach errorhandler to every entry of default_exceptions
for code in default_exceptions:
    attach_handler = app.errorhandler(code)
    attach_handler(errorhandler)
|