This commit is contained in:
Lieuwe
2011-03-18 17:04:51 +01:00
parent bc8af4e210
commit 9c39875ef0
1431 changed files with 505090 additions and 3 deletions

1
.gitignore vendored
View File

@@ -9,3 +9,4 @@ build/stamps/*
utils/*
*.log
build/ext_*
src/python/stdlib/*

49
addzip.py Normal file
View File

@@ -0,0 +1,49 @@
import zipfile
import os
import os.path
import sys
if(len(sys.argv)>1 and sys.argv[1]=="--clean"):
print "cleaning"
for dirpath,dirnames,filenames in os.walk("./src/python/stdlib"):
for fname in filenames:
if(fname.endswith(".pyc") or fname.endswith(".pyo")):
os.remove(os.path.join(dirpath,fname))
raise SystemExit
print "zipping stdlib"
fid=zipfile.ZipFile("./build/stdlib.zip","w",zipfile.ZIP_DEFLATED)
#ZipFile.write(filename)
files=os.walk("./src/python/stdlib")
num=0
pn=0
for dirpath,dirnames,filenames in files:
for fname in filenames:
if(fname.endswith(".py")):
continue
fid.write(os.path.join(dirpath,fname))
num+=1
if(num-5>=pn):
pn=num
print "%d done."%num
print "writing zipfile"
fid.close()
raise SystemExit
"""not needed."""
print "generating pystdlib.h"
with open("stdlib.zip","r") as fid:
with open("./includes/pystdlib.h","w") as outfid:
outfid.write("unsigned char tpt_console_stdlib[] = {")
tmp=0
for char in fid.read():
outfid.write(hex(ord(char)))
outfid.write(",")
tmp+=1
outfid.write("};\n")
outfid.write("size_t tpt_console_stdlibsize=%d;"%tmp)
print "done"

Binary file not shown.

BIN
build/stdlib.zip Normal file

Binary file not shown.

View File

@@ -42,7 +42,7 @@ else:
print "internal"
#unsigned char tpt_console_pyc[] = { 0x1B, 0x57};
lst=[]
compileall.compile_dir("./src/python", force=1)
compileall.compile_dir("./src/python", force=0)
print "generating pyconsole.h"

View File

@@ -29,6 +29,7 @@
#ifdef PYCONSOLE
#include "Python.h"
#include "pyconsole.h"
//#include "pystdlib.h"
#endif
#include <stdio.h>
@@ -1940,6 +1941,29 @@ emb_delete(PyObject *self, PyObject *args)
return Py_BuildValue("i",1);
}
static PyObject*
emb_set_pressure(PyObject *self, PyObject *args)
{
////SDL_EnableKeyRepeat(delay,interval)
int x,y,press;
if(!PyArg_ParseTuple(args, "iii:set_pressure",&x,&y,&press))
return NULL;
pv[y/CELL][x/CELL]=press;
return Py_BuildValue("i",1);
}
static PyObject*
emb_set_velocity(PyObject *self, PyObject *args)
{
////SDL_EnableKeyRepeat(delay,interval)
int x,y,xv,yv;
if(!PyArg_ParseTuple(args, "iiii:set_velocity",&x,&y,&xv,&yv))
return NULL;
vx[y/CELL][x/CELL]=xv;
vy[y/CELL][x/CELL]=yv;
return Py_BuildValue("i",1);
}
static PyMethodDef EmbMethods[] = { //WARNING! don't forget to register your function here!
{"create", (PyCFunction)emb_create, METH_VARARGS|METH_KEYWORDS, "create a particle."},
{"log", (PyCFunction)emb_log, METH_VARARGS, "logs an error string to the console."},
@@ -1977,6 +2001,8 @@ static PyMethodDef EmbMethods[] = { //WARNING! don't forget to register your fun
{"get_modifier", (PyCFunction)emb_get_modifier, METH_VARARGS, "get pressed modifier keys"},
{"set_keyrepeat", (PyCFunction)emb_set_keyrepeat, METH_VARARGS, "set key repeat rate."},
{"delete", (PyCFunction)emb_delete, METH_VARARGS, "delete a particle"},
{"set_pressure", (PyCFunction)emb_set_pressure, METH_VARARGS, "set pressure"},
{"set_velocity", (PyCFunction)emb_set_velocity, METH_VARARGS, "set velocity"},
{NULL, NULL, 0, NULL}
};
#endif
@@ -2027,11 +2053,20 @@ int main(int argc, char *argv[])
fmt.userdata = NULL;
#ifdef PYCONSOLE
//dump tpt_console_stdlib to file here
/*#include "pystdlib.h"
FILE* fid;
fid=fopen("stdlib.zip","w");
//fputs(tpt_console_stdlib,fid);
fwrite(&tpt_console_stdlib,1,tpt_console_stdlibsize,fid);
fclose(fid);*/
//crashes gcc
//initialise python console
Py_Initialize();
Py_InitModule("tpt", EmbMethods);
//change the path to find all the correct modules
PyRun_SimpleString("import sys\nsys.path.append('.')");
PyRun_SimpleString("import sys\nsys.path.append('./stdlib.zip')\nsys.path.append('.')");
PyRun_SimpleString("print 'python present.'");
//load the console module and whatnot
#ifdef PYEXT

View File

@@ -0,0 +1,600 @@
"""HTTP server base class.
Note: the class in this module doesn't implement any HTTP request; see
SimpleHTTPServer for simple implementations of GET, HEAD and POST
(including CGI scripts). It does, however, optionally implement HTTP/1.1
persistent connections, as of version 0.3.
Contents:
- BaseHTTPRequestHandler: HTTP request handler base class
- test: test function
XXX To do:
- log requests even later (to capture byte count)
- log user-agent header and other interesting goodies
- send error log to separate file
"""
# See also:
#
# HTTP Working Group T. Berners-Lee
# INTERNET-DRAFT R. T. Fielding
# <draft-ietf-http-v10-spec-00.txt> H. Frystyk Nielsen
# Expires September 8, 1995 March 8, 1995
#
# URL: http://www.ics.uci.edu/pub/ietf/http/draft-ietf-http-v10-spec-00.txt
#
# and
#
# Network Working Group R. Fielding
# Request for Comments: 2616 et al
# Obsoletes: 2068 June 1999
# Category: Standards Track
#
# URL: http://www.faqs.org/rfcs/rfc2616.html
# Log files
# ---------
#
# Here's a quote from the NCSA httpd docs about log file format.
#
# | The logfile format is as follows. Each line consists of:
# |
# | host rfc931 authuser [DD/Mon/YYYY:hh:mm:ss] "request" ddd bbbb
# |
# | host: Either the DNS name or the IP number of the remote client
# | rfc931: Any information returned by identd for this person,
# | - otherwise.
# | authuser: If user sent a userid for authentication, the user name,
# | - otherwise.
# | DD: Day
# | Mon: Month (calendar name)
# | YYYY: Year
# | hh: hour (24-hour format, the machine's timezone)
# | mm: minutes
# | ss: seconds
# | request: The first line of the HTTP request as sent by the client.
# | ddd: the status code returned by the server, - if not available.
# | bbbb: the total number of bytes sent,
# | *not including the HTTP/1.0 header*, - if not available
# |
# | You can determine the name of the file accessed through request.
#
# (Actually, the latter is only true if you know the server configuration
# at the time the request was made!)
__version__ = "0.3"
__all__ = ["HTTPServer", "BaseHTTPRequestHandler"]
import sys
import time
import socket # For gethostbyaddr()
from warnings import filterwarnings, catch_warnings
with catch_warnings():
if sys.py3kwarning:
filterwarnings("ignore", ".*mimetools has been removed",
DeprecationWarning)
import mimetools
import SocketServer
# Default error message template
DEFAULT_ERROR_MESSAGE = """\
<head>
<title>Error response</title>
</head>
<body>
<h1>Error response</h1>
<p>Error code %(code)d.
<p>Message: %(message)s.
<p>Error code explanation: %(code)s = %(explain)s.
</body>
"""
DEFAULT_ERROR_CONTENT_TYPE = "text/html"
def _quote_html(html):
return html.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
class HTTPServer(SocketServer.TCPServer):
allow_reuse_address = 1 # Seems to make sense in testing environment
def server_bind(self):
"""Override server_bind to store the server name."""
SocketServer.TCPServer.server_bind(self)
host, port = self.socket.getsockname()[:2]
self.server_name = socket.getfqdn(host)
self.server_port = port
class BaseHTTPRequestHandler(SocketServer.StreamRequestHandler):
"""HTTP request handler base class.
The following explanation of HTTP serves to guide you through the
code as well as to expose any misunderstandings I may have about
HTTP (so you don't need to read the code to figure out I'm wrong
:-).
HTTP (HyperText Transfer Protocol) is an extensible protocol on
top of a reliable stream transport (e.g. TCP/IP). The protocol
recognizes three parts to a request:
1. One line identifying the request type and path
2. An optional set of RFC-822-style headers
3. An optional data part
The headers and data are separated by a blank line.
The first line of the request has the form
<command> <path> <version>
where <command> is a (case-sensitive) keyword such as GET or POST,
<path> is a string containing path information for the request,
and <version> should be the string "HTTP/1.0" or "HTTP/1.1".
<path> is encoded using the URL encoding scheme (using %xx to signify
the ASCII character with hex code xx).
The specification specifies that lines are separated by CRLF but
for compatibility with the widest range of clients recommends
servers also handle LF. Similarly, whitespace in the request line
is treated sensibly (allowing multiple spaces between components
and allowing trailing whitespace).
Similarly, for output, lines ought to be separated by CRLF pairs
but most clients grok LF characters just fine.
If the first line of the request has the form
<command> <path>
(i.e. <version> is left out) then this is assumed to be an HTTP
0.9 request; this form has no optional headers and data part and
the reply consists of just the data.
The reply form of the HTTP 1.x protocol again has three parts:
1. One line giving the response code
2. An optional set of RFC-822-style headers
3. The data
Again, the headers and data are separated by a blank line.
The response code line has the form
<version> <responsecode> <responsestring>
where <version> is the protocol version ("HTTP/1.0" or "HTTP/1.1"),
<responsecode> is a 3-digit response code indicating success or
failure of the request, and <responsestring> is an optional
human-readable string explaining what the response code means.
This server parses the request and the headers, and then calls a
function specific to the request type (<command>). Specifically,
a request SPAM will be handled by a method do_SPAM(). If no
such method exists the server sends an error response to the
client. If it exists, it is called with no arguments:
do_SPAM()
Note that the request name is case sensitive (i.e. SPAM and spam
are different requests).
The various request details are stored in instance variables:
- client_address is the client IP address in the form (host,
port);
- command, path and version are the broken-down request line;
- headers is an instance of mimetools.Message (or a derived
class) containing the header information;
- rfile is a file object open for reading positioned at the
start of the optional input data part;
- wfile is a file object open for writing.
IT IS IMPORTANT TO ADHERE TO THE PROTOCOL FOR WRITING!
The first thing to be written must be the response line. Then
follow 0 or more header lines, then a blank line, and then the
actual data (if any). The meaning of the header lines depends on
the command executed by the server; in most cases, when data is
returned, there should be at least one header line of the form
Content-type: <type>/<subtype>
where <type> and <subtype> should be registered MIME types,
e.g. "text/html" or "text/plain".
"""
# The Python system version, truncated to its first component.
sys_version = "Python/" + sys.version.split()[0]
# The server software version. You may want to override this.
# The format is multiple whitespace-separated strings,
# where each string is of the form name[/version].
server_version = "BaseHTTP/" + __version__
# The default request version. This only affects responses up until
# the point where the request line is parsed, so it mainly decides what
# the client gets back when sending a malformed request line.
# Most web servers default to HTTP 0.9, i.e. don't send a status line.
default_request_version = "HTTP/0.9"
def parse_request(self):
"""Parse a request (internal).
The request should be stored in self.raw_requestline; the results
are in self.command, self.path, self.request_version and
self.headers.
Return True for success, False for failure; on failure, an
error is sent back.
"""
self.command = None # set in case of error on the first line
self.request_version = version = self.default_request_version
self.close_connection = 1
requestline = self.raw_requestline
if requestline[-2:] == '\r\n':
requestline = requestline[:-2]
elif requestline[-1:] == '\n':
requestline = requestline[:-1]
self.requestline = requestline
words = requestline.split()
if len(words) == 3:
[command, path, version] = words
if version[:5] != 'HTTP/':
self.send_error(400, "Bad request version (%r)" % version)
return False
try:
base_version_number = version.split('/', 1)[1]
version_number = base_version_number.split(".")
# RFC 2145 section 3.1 says there can be only one "." and
# - major and minor numbers MUST be treated as
# separate integers;
# - HTTP/2.4 is a lower version than HTTP/2.13, which in
# turn is lower than HTTP/12.3;
# - Leading zeros MUST be ignored by recipients.
if len(version_number) != 2:
raise ValueError
version_number = int(version_number[0]), int(version_number[1])
except (ValueError, IndexError):
self.send_error(400, "Bad request version (%r)" % version)
return False
if version_number >= (1, 1) and self.protocol_version >= "HTTP/1.1":
self.close_connection = 0
if version_number >= (2, 0):
self.send_error(505,
"Invalid HTTP Version (%s)" % base_version_number)
return False
elif len(words) == 2:
[command, path] = words
self.close_connection = 1
if command != 'GET':
self.send_error(400,
"Bad HTTP/0.9 request type (%r)" % command)
return False
elif not words:
return False
else:
self.send_error(400, "Bad request syntax (%r)" % requestline)
return False
self.command, self.path, self.request_version = command, path, version
# Examine the headers and look for a Connection directive
self.headers = self.MessageClass(self.rfile, 0)
conntype = self.headers.get('Connection', "")
if conntype.lower() == 'close':
self.close_connection = 1
elif (conntype.lower() == 'keep-alive' and
self.protocol_version >= "HTTP/1.1"):
self.close_connection = 0
return True
def handle_one_request(self):
"""Handle a single HTTP request.
You normally don't need to override this method; see the class
__doc__ string for information on how to handle specific HTTP
commands such as GET and POST.
"""
try:
self.raw_requestline = self.rfile.readline()
if not self.raw_requestline:
self.close_connection = 1
return
if not self.parse_request():
# An error code has been sent, just exit
return
mname = 'do_' + self.command
if not hasattr(self, mname):
self.send_error(501, "Unsupported method (%r)" % self.command)
return
method = getattr(self, mname)
method()
self.wfile.flush() #actually send the response if not already done.
except socket.timeout, e:
#a read or a write timed out. Discard this connection
self.log_error("Request timed out: %r", e)
self.close_connection = 1
return
def handle(self):
"""Handle multiple requests if necessary."""
self.close_connection = 1
self.handle_one_request()
while not self.close_connection:
self.handle_one_request()
def send_error(self, code, message=None):
"""Send and log an error reply.
Arguments are the error code, and a detailed message.
The detailed message defaults to the short entry matching the
response code.
This sends an error response (so it must be called before any
output has been generated), logs the error, and finally sends
a piece of HTML explaining the error to the user.
"""
try:
short, long = self.responses[code]
except KeyError:
short, long = '???', '???'
if message is None:
message = short
explain = long
self.log_error("code %d, message %s", code, message)
# using _quote_html to prevent Cross Site Scripting attacks (see bug #1100201)
content = (self.error_message_format %
{'code': code, 'message': _quote_html(message), 'explain': explain})
self.send_response(code, message)
self.send_header("Content-Type", self.error_content_type)
self.send_header('Connection', 'close')
self.end_headers()
if self.command != 'HEAD' and code >= 200 and code not in (204, 304):
self.wfile.write(content)
error_message_format = DEFAULT_ERROR_MESSAGE
error_content_type = DEFAULT_ERROR_CONTENT_TYPE
def send_response(self, code, message=None):
"""Send the response header and log the response code.
Also send two standard headers with the server software
version and the current date.
"""
self.log_request(code)
if message is None:
if code in self.responses:
message = self.responses[code][0]
else:
message = ''
if self.request_version != 'HTTP/0.9':
self.wfile.write("%s %d %s\r\n" %
(self.protocol_version, code, message))
# print (self.protocol_version, code, message)
self.send_header('Server', self.version_string())
self.send_header('Date', self.date_time_string())
def send_header(self, keyword, value):
"""Send a MIME header."""
if self.request_version != 'HTTP/0.9':
self.wfile.write("%s: %s\r\n" % (keyword, value))
if keyword.lower() == 'connection':
if value.lower() == 'close':
self.close_connection = 1
elif value.lower() == 'keep-alive':
self.close_connection = 0
def end_headers(self):
"""Send the blank line ending the MIME headers."""
if self.request_version != 'HTTP/0.9':
self.wfile.write("\r\n")
def log_request(self, code='-', size='-'):
"""Log an accepted request.
This is called by send_response().
"""
self.log_message('"%s" %s %s',
self.requestline, str(code), str(size))
def log_error(self, format, *args):
"""Log an error.
This is called when a request cannot be fulfilled. By
default it passes the message on to log_message().
Arguments are the same as for log_message().
XXX This should go to the separate error log.
"""
self.log_message(format, *args)
def log_message(self, format, *args):
"""Log an arbitrary message.
This is used by all other logging functions. Override
it if you have specific logging wishes.
The first argument, FORMAT, is a format string for the
message to be logged. If the format string contains
any % escapes requiring parameters, they should be
specified as subsequent arguments (it's just like
printf!).
The client host and current date/time are prefixed to
every message.
"""
sys.stderr.write("%s - - [%s] %s\n" %
(self.address_string(),
self.log_date_time_string(),
format%args))
def version_string(self):
"""Return the server software version string."""
return self.server_version + ' ' + self.sys_version
def date_time_string(self, timestamp=None):
"""Return the current date and time formatted for a message header."""
if timestamp is None:
timestamp = time.time()
year, month, day, hh, mm, ss, wd, y, z = time.gmtime(timestamp)
s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
self.weekdayname[wd],
day, self.monthname[month], year,
hh, mm, ss)
return s
def log_date_time_string(self):
"""Return the current time formatted for logging."""
now = time.time()
year, month, day, hh, mm, ss, x, y, z = time.localtime(now)
s = "%02d/%3s/%04d %02d:%02d:%02d" % (
day, self.monthname[month], year, hh, mm, ss)
return s
weekdayname = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
monthname = [None,
'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
def address_string(self):
"""Return the client address formatted for logging.
This version looks up the full hostname using gethostbyaddr(),
and tries to find a name that contains at least one dot.
"""
host, port = self.client_address[:2]
return socket.getfqdn(host)
# Essentially static class variables
# The version of the HTTP protocol we support.
# Set this to HTTP/1.1 to enable automatic keepalive
protocol_version = "HTTP/1.0"
# The Message-like class used to parse headers
MessageClass = mimetools.Message
# Table mapping response codes to messages; entries have the
# form {code: (shortmessage, longmessage)}.
# See RFC 2616.
responses = {
100: ('Continue', 'Request received, please continue'),
101: ('Switching Protocols',
'Switching to new protocol; obey Upgrade header'),
200: ('OK', 'Request fulfilled, document follows'),
201: ('Created', 'Document created, URL follows'),
202: ('Accepted',
'Request accepted, processing continues off-line'),
203: ('Non-Authoritative Information', 'Request fulfilled from cache'),
204: ('No Content', 'Request fulfilled, nothing follows'),
205: ('Reset Content', 'Clear input form for further input.'),
206: ('Partial Content', 'Partial content follows.'),
300: ('Multiple Choices',
'Object has several resources -- see URI list'),
301: ('Moved Permanently', 'Object moved permanently -- see URI list'),
302: ('Found', 'Object moved temporarily -- see URI list'),
303: ('See Other', 'Object moved -- see Method and URL list'),
304: ('Not Modified',
'Document has not changed since given time'),
305: ('Use Proxy',
'You must use proxy specified in Location to access this '
'resource.'),
307: ('Temporary Redirect',
'Object moved temporarily -- see URI list'),
400: ('Bad Request',
'Bad request syntax or unsupported method'),
401: ('Unauthorized',
'No permission -- see authorization schemes'),
402: ('Payment Required',
'No payment -- see charging schemes'),
403: ('Forbidden',
'Request forbidden -- authorization will not help'),
404: ('Not Found', 'Nothing matches the given URI'),
405: ('Method Not Allowed',
'Specified method is invalid for this resource.'),
406: ('Not Acceptable', 'URI not available in preferred format.'),
407: ('Proxy Authentication Required', 'You must authenticate with '
'this proxy before proceeding.'),
408: ('Request Timeout', 'Request timed out; try again later.'),
409: ('Conflict', 'Request conflict.'),
410: ('Gone',
'URI no longer exists and has been permanently removed.'),
411: ('Length Required', 'Client must specify Content-Length.'),
412: ('Precondition Failed', 'Precondition in headers is false.'),
413: ('Request Entity Too Large', 'Entity is too large.'),
414: ('Request-URI Too Long', 'URI is too long.'),
415: ('Unsupported Media Type', 'Entity body in unsupported format.'),
416: ('Requested Range Not Satisfiable',
'Cannot satisfy request range.'),
417: ('Expectation Failed',
'Expect condition could not be satisfied.'),
500: ('Internal Server Error', 'Server got itself in trouble'),
501: ('Not Implemented',
'Server does not support this operation'),
502: ('Bad Gateway', 'Invalid responses from another server/proxy.'),
503: ('Service Unavailable',
'The server cannot process the request due to a high load'),
504: ('Gateway Timeout',
'The gateway server did not receive a timely response'),
505: ('HTTP Version Not Supported', 'Cannot fulfill request.'),
}
def test(HandlerClass = BaseHTTPRequestHandler,
ServerClass = HTTPServer, protocol="HTTP/1.0"):
"""Test the HTTP request handler class.
This runs an HTTP server on port 8000 (or the first command line
argument).
"""
if sys.argv[1:]:
port = int(sys.argv[1])
else:
port = 8000
server_address = ('', port)
HandlerClass.protocol_version = protocol
httpd = ServerClass(server_address, HandlerClass)
sa = httpd.socket.getsockname()
print "Serving HTTP on", sa[0], "port", sa[1], "..."
httpd.serve_forever()
if __name__ == '__main__':
test()

View File

@@ -0,0 +1,180 @@
"""Bastionification utility.
A bastion (for another object -- the 'original') is an object that has
the same methods as the original but does not give access to its
instance variables. Bastions have a number of uses, but the most
obvious one is to provide code executing in restricted mode with a
safe interface to an object implemented in unrestricted mode.
The bastionification routine has an optional second argument which is
a filter function. Only those methods for which the filter method
(called with the method name as argument) returns true are accessible.
The default filter method returns true unless the method name begins
with an underscore.
There are a number of possible implementations of bastions. We use a
'lazy' approach where the bastion's __getattr__() discipline does all
the work for a particular method the first time it is used. This is
usually fastest, especially if the user doesn't call all available
methods. The retrieved methods are stored as instance variables of
the bastion, so the overhead is only occurred on the first use of each
method.
Detail: the bastion class has a __repr__() discipline which includes
the repr() of the original object. This is precomputed when the
bastion is created.
"""
from warnings import warnpy3k
warnpy3k("the Bastion module has been removed in Python 3.0", stacklevel=2)
del warnpy3k
__all__ = ["BastionClass", "Bastion"]
from types import MethodType
class BastionClass:
"""Helper class used by the Bastion() function.
You could subclass this and pass the subclass as the bastionclass
argument to the Bastion() function, as long as the constructor has
the same signature (a get() function and a name for the object).
"""
def __init__(self, get, name):
"""Constructor.
Arguments:
get - a function that gets the attribute value (by name)
name - a human-readable name for the original object
(suggestion: use repr(object))
"""
self._get_ = get
self._name_ = name
def __repr__(self):
"""Return a representation string.
This includes the name passed in to the constructor, so that
if you print the bastion during debugging, at least you have
some idea of what it is.
"""
return "<Bastion for %s>" % self._name_
def __getattr__(self, name):
"""Get an as-yet undefined attribute value.
This calls the get() function that was passed to the
constructor. The result is stored as an instance variable so
that the next time the same attribute is requested,
__getattr__() won't be invoked.
If the get() function raises an exception, this is simply
passed on -- exceptions are not cached.
"""
attribute = self._get_(name)
self.__dict__[name] = attribute
return attribute
def Bastion(object, filter = lambda name: name[:1] != '_',
name=None, bastionclass=BastionClass):
"""Create a bastion for an object, using an optional filter.
See the Bastion module's documentation for background.
Arguments:
object - the original object
filter - a predicate that decides whether a function name is OK;
by default all names are OK that don't start with '_'
name - the name of the object; default repr(object)
bastionclass - class used to create the bastion; default BastionClass
"""
raise RuntimeError, "This code is not secure in Python 2.2 and later"
# Note: we define *two* ad-hoc functions here, get1 and get2.
# Both are intended to be called in the same way: get(name).
# It is clear that the real work (getting the attribute
# from the object and calling the filter) is done in get1.
# Why can't we pass get1 to the bastion? Because the user
# would be able to override the filter argument! With get2,
# overriding the default argument is no security loophole:
# all it does is call it.
# Also notice that we can't place the object and filter as
# instance variables on the bastion object itself, since
# the user has full access to all instance variables!
def get1(name, object=object, filter=filter):
"""Internal function for Bastion(). See source comments."""
if filter(name):
attribute = getattr(object, name)
if type(attribute) == MethodType:
return attribute
raise AttributeError, name
def get2(name, get1=get1):
"""Internal function for Bastion(). See source comments."""
return get1(name)
if name is None:
name = repr(object)
return bastionclass(get2, name)
def _test():
"""Test the Bastion() function."""
class Original:
def __init__(self):
self.sum = 0
def add(self, n):
self._add(n)
def _add(self, n):
self.sum = self.sum + n
def total(self):
return self.sum
o = Original()
b = Bastion(o)
testcode = """if 1:
b.add(81)
b.add(18)
print "b.total() =", b.total()
try:
print "b.sum =", b.sum,
except:
print "inaccessible"
else:
print "accessible"
try:
print "b._add =", b._add,
except:
print "inaccessible"
else:
print "accessible"
try:
print "b._get_.func_defaults =", map(type, b._get_.func_defaults),
except:
print "inaccessible"
else:
print "accessible"
\n"""
exec testcode
print '='*20, "Using rexec:", '='*20
import rexec
r = rexec.RExec()
m = r.add_module('__main__')
m.b = b
r.r_exec(testcode)
if __name__ == '__main__':
_test()

View File

@@ -0,0 +1,374 @@
"""CGI-savvy HTTP Server.
This module builds on SimpleHTTPServer by implementing GET and POST
requests to cgi-bin scripts.
If the os.fork() function is not present (e.g. on Windows),
os.popen2() is used as a fallback, with slightly altered semantics; if
that function is not present either (e.g. on Macintosh), only Python
scripts are supported, and they are executed by the current process.
In all cases, the implementation is intentionally naive -- all
requests are executed sychronously.
SECURITY WARNING: DON'T USE THIS CODE UNLESS YOU ARE INSIDE A FIREWALL
-- it may execute arbitrary Python code or external programs.
Note that status code 200 is sent prior to execution of a CGI script, so
scripts cannot send other status codes such as 302 (redirect).
"""
__version__ = "0.4"
__all__ = ["CGIHTTPRequestHandler"]
import os
import sys
import urllib
import BaseHTTPServer
import SimpleHTTPServer
import select
import copy
class CGIHTTPRequestHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
"""Complete HTTP server with GET, HEAD and POST commands.
GET and HEAD also support running CGI scripts.
The POST command is *only* implemented for CGI scripts.
"""
# Determine platform specifics
have_fork = hasattr(os, 'fork')
have_popen2 = hasattr(os, 'popen2')
have_popen3 = hasattr(os, 'popen3')
# Make rfile unbuffered -- we need to read one line and then pass
# the rest to a subprocess, so we can't use buffered input.
rbufsize = 0
def do_POST(self):
"""Serve a POST request.
This is only implemented for CGI scripts.
"""
if self.is_cgi():
self.run_cgi()
else:
self.send_error(501, "Can only POST to CGI scripts")
def send_head(self):
"""Version of send_head that support CGI scripts"""
if self.is_cgi():
return self.run_cgi()
else:
return SimpleHTTPServer.SimpleHTTPRequestHandler.send_head(self)
def is_cgi(self):
"""Test whether self.path corresponds to a CGI script.
Returns True and updates the cgi_info attribute to the tuple
(dir, rest) if self.path requires running a CGI script.
Returns False otherwise.
If any exception is raised, the caller should assume that
self.path was rejected as invalid and act accordingly.
The default implementation tests whether the normalized url
path begins with one of the strings in self.cgi_directories
(and the next character is a '/' or the end of the string).
"""
splitpath = _url_collapse_path_split(self.path)
if splitpath[0] in self.cgi_directories:
self.cgi_info = splitpath
return True
return False
cgi_directories = ['/cgi-bin', '/htbin']
def is_executable(self, path):
"""Test whether argument path is an executable file."""
return executable(path)
def is_python(self, path):
"""Test whether argument path is a Python script."""
head, tail = os.path.splitext(path)
return tail.lower() in (".py", ".pyw")
def run_cgi(self):
"""Execute a CGI script."""
path = self.path
dir, rest = self.cgi_info
i = path.find('/', len(dir) + 1)
while i >= 0:
nextdir = path[:i]
nextrest = path[i+1:]
scriptdir = self.translate_path(nextdir)
if os.path.isdir(scriptdir):
dir, rest = nextdir, nextrest
i = path.find('/', len(dir) + 1)
else:
break
# find an explicit query string, if present.
i = rest.rfind('?')
if i >= 0:
rest, query = rest[:i], rest[i+1:]
else:
query = ''
# dissect the part after the directory name into a script name &
# a possible additional path, to be stored in PATH_INFO.
i = rest.find('/')
if i >= 0:
script, rest = rest[:i], rest[i:]
else:
script, rest = rest, ''
scriptname = dir + '/' + script
scriptfile = self.translate_path(scriptname)
if not os.path.exists(scriptfile):
self.send_error(404, "No such CGI script (%r)" % scriptname)
return
if not os.path.isfile(scriptfile):
self.send_error(403, "CGI script is not a plain file (%r)" %
scriptname)
return
ispy = self.is_python(scriptname)
if not ispy:
if not (self.have_fork or self.have_popen2 or self.have_popen3):
self.send_error(403, "CGI script is not a Python script (%r)" %
scriptname)
return
if not self.is_executable(scriptfile):
self.send_error(403, "CGI script is not executable (%r)" %
scriptname)
return
# Reference: http://hoohoo.ncsa.uiuc.edu/cgi/env.html
# XXX Much of the following could be prepared ahead of time!
env = copy.deepcopy(os.environ)
env['SERVER_SOFTWARE'] = self.version_string()
env['SERVER_NAME'] = self.server.server_name
env['GATEWAY_INTERFACE'] = 'CGI/1.1'
env['SERVER_PROTOCOL'] = self.protocol_version
env['SERVER_PORT'] = str(self.server.server_port)
env['REQUEST_METHOD'] = self.command
uqrest = urllib.unquote(rest)
env['PATH_INFO'] = uqrest
env['PATH_TRANSLATED'] = self.translate_path(uqrest)
env['SCRIPT_NAME'] = scriptname
if query:
env['QUERY_STRING'] = query
host = self.address_string()
if host != self.client_address[0]:
env['REMOTE_HOST'] = host
env['REMOTE_ADDR'] = self.client_address[0]
authorization = self.headers.getheader("authorization")
if authorization:
authorization = authorization.split()
if len(authorization) == 2:
import base64, binascii
env['AUTH_TYPE'] = authorization[0]
if authorization[0].lower() == "basic":
try:
authorization = base64.decodestring(authorization[1])
except binascii.Error:
pass
else:
authorization = authorization.split(':')
if len(authorization) == 2:
env['REMOTE_USER'] = authorization[0]
# XXX REMOTE_IDENT
if self.headers.typeheader is None:
env['CONTENT_TYPE'] = self.headers.type
else:
env['CONTENT_TYPE'] = self.headers.typeheader
length = self.headers.getheader('content-length')
if length:
env['CONTENT_LENGTH'] = length
referer = self.headers.getheader('referer')
if referer:
env['HTTP_REFERER'] = referer
accept = []
for line in self.headers.getallmatchingheaders('accept'):
if line[:1] in "\t\n\r ":
accept.append(line.strip())
else:
accept = accept + line[7:].split(',')
env['HTTP_ACCEPT'] = ','.join(accept)
ua = self.headers.getheader('user-agent')
if ua:
env['HTTP_USER_AGENT'] = ua
co = filter(None, self.headers.getheaders('cookie'))
if co:
env['HTTP_COOKIE'] = ', '.join(co)
# XXX Other HTTP_* headers
# Since we're setting the env in the parent, provide empty
# values to override previously set values
for k in ('QUERY_STRING', 'REMOTE_HOST', 'CONTENT_LENGTH',
'HTTP_USER_AGENT', 'HTTP_COOKIE', 'HTTP_REFERER'):
env.setdefault(k, "")
self.send_response(200, "Script output follows")
decoded_query = query.replace('+', ' ')
if self.have_fork:
# Unix -- fork as we should
args = [script]
if '=' not in decoded_query:
args.append(decoded_query)
nobody = nobody_uid()
self.wfile.flush() # Always flush before forking
pid = os.fork()
if pid != 0:
# Parent
pid, sts = os.waitpid(pid, 0)
# throw away additional data [see bug #427345]
while select.select([self.rfile], [], [], 0)[0]:
if not self.rfile.read(1):
break
if sts:
self.log_error("CGI script exit status %#x", sts)
return
# Child
try:
try:
os.setuid(nobody)
except os.error:
pass
os.dup2(self.rfile.fileno(), 0)
os.dup2(self.wfile.fileno(), 1)
os.execve(scriptfile, args, env)
except:
self.server.handle_error(self.request, self.client_address)
os._exit(127)
else:
# Non Unix - use subprocess
import subprocess
cmdline = [scriptfile]
if self.is_python(scriptfile):
interp = sys.executable
if interp.lower().endswith("w.exe"):
# On Windows, use python.exe, not pythonw.exe
interp = interp[:-5] + interp[-4:]
cmdline = [interp, '-u'] + cmdline
if '=' not in query:
cmdline.append(query)
self.log_message("command: %s", subprocess.list2cmdline(cmdline))
try:
nbytes = int(length)
except (TypeError, ValueError):
nbytes = 0
p = subprocess.Popen(cmdline,
stdin = subprocess.PIPE,
stdout = subprocess.PIPE,
stderr = subprocess.PIPE,
env = env
)
if self.command.lower() == "post" and nbytes > 0:
data = self.rfile.read(nbytes)
else:
data = None
# throw away additional data [see bug #427345]
while select.select([self.rfile._sock], [], [], 0)[0]:
if not self.rfile._sock.recv(1):
break
stdout, stderr = p.communicate(data)
self.wfile.write(stdout)
if stderr:
self.log_error('%s', stderr)
p.stderr.close()
p.stdout.close()
status = p.returncode
if status:
self.log_error("CGI script exit status %#x", status)
else:
self.log_message("CGI script exited OK")
# TODO(gregory.p.smith): Move this into an appropriate library.
def _url_collapse_path_split(path):
"""
Given a URL path, remove extra '/'s and '.' path elements and collapse
any '..' references.
Implements something akin to RFC-2396 5.2 step 6 to parse relative paths.
Returns: A tuple of (head, tail) where tail is everything after the final /
and head is everything before it. Head will always start with a '/' and,
if it contains anything else, never have a trailing '/'.
Raises: IndexError if too many '..' occur within the path.
"""
# Similar to os.path.split(os.path.normpath(path)) but specific to URL
# path semantics rather than local operating system semantics.
path_parts = []
for part in path.split('/'):
if part == '.':
path_parts.append('')
else:
path_parts.append(part)
# Filter out blank non trailing parts before consuming the '..'.
path_parts = [part for part in path_parts[:-1] if part] + path_parts[-1:]
if path_parts:
tail_part = path_parts.pop()
else:
tail_part = ''
head_parts = []
for part in path_parts:
if part == '..':
head_parts.pop()
else:
head_parts.append(part)
if tail_part and tail_part == '..':
head_parts.pop()
tail_part = ''
return ('/' + '/'.join(head_parts), tail_part)
nobody = None
def nobody_uid():
"""Internal routine to get nobody's uid"""
global nobody
if nobody:
return nobody
try:
import pwd
except ImportError:
return -1
try:
nobody = pwd.getpwnam('nobody')[2]
except KeyError:
nobody = 1 + max(map(lambda x: x[2], pwd.getpwall()))
return nobody
def executable(path):
"""Test for executable file."""
try:
st = os.stat(path)
except os.error:
return False
return st.st_mode & 0111 != 0
def test(HandlerClass = CGIHTTPRequestHandler,
ServerClass = BaseHTTPServer.HTTPServer):
SimpleHTTPServer.test(HandlerClass, ServerClass)
if __name__ == '__main__':
test()

View File

@@ -0,0 +1,711 @@
"""Configuration file parser.
A setup file consists of sections, lead by a "[section]" header,
and followed by "name: value" entries, with continuations and such in
the style of RFC 822.
The option values can contain format strings which refer to other values in
the same section, or values in a special [DEFAULT] section.
For example:
something: %(dir)s/whatever
would resolve the "%(dir)s" to the value of dir. All reference
expansions are done late, on demand.
Intrinsic defaults can be specified by passing them into the
ConfigParser constructor as a dictionary.
class:
ConfigParser -- responsible for parsing a list of
configuration files, and managing the parsed database.
methods:
__init__(defaults=None)
create the parser and specify a dictionary of intrinsic defaults. The
keys must be strings, the values must be appropriate for %()s string
interpolation. Note that `__name__' is always an intrinsic default;
its value is the section's name.
sections()
return all the configuration section names, sans DEFAULT
has_section(section)
return whether the given section exists
has_option(section, option)
return whether the given option exists in the given section
options(section)
return list of configuration options for the named section
read(filenames)
read and parse the list of named configuration files, given by
name. A single filename is also allowed. Non-existing files
are ignored. Return list of successfully read files.
readfp(fp, filename=None)
read and parse one configuration file, given as a file object.
The filename defaults to fp.name; it is only used in error
messages (if fp has no `name' attribute, the string `<???>' is used).
get(section, option, raw=False, vars=None)
return a string value for the named option. All % interpolations are
expanded in the return values, based on the defaults passed into the
constructor and the DEFAULT section. Additional substitutions may be
provided using the `vars' argument, which must be a dictionary whose
contents override any pre-existing defaults.
getint(section, options)
like get(), but convert value to an integer
getfloat(section, options)
like get(), but convert value to a float
getboolean(section, options)
like get(), but convert value to a boolean (currently case
insensitively defined as 0, false, no, off for False, and 1, true,
yes, on for True). Returns False or True.
items(section, raw=False, vars=None)
return a list of tuples with (name, value) for each option
in the section.
remove_section(section)
remove the given file section and all its options
remove_option(section, option)
remove the given option from the given section
set(section, option, value)
set the given option
write(fp)
write the configuration state in .ini format
"""
try:
from collections import OrderedDict as _default_dict
except ImportError:
# fallback for setup.py which hasn't yet built _collections
_default_dict = dict
import re
__all__ = ["NoSectionError", "DuplicateSectionError", "NoOptionError",
"InterpolationError", "InterpolationDepthError",
"InterpolationSyntaxError", "ParsingError",
"MissingSectionHeaderError",
"ConfigParser", "SafeConfigParser", "RawConfigParser",
"DEFAULTSECT", "MAX_INTERPOLATION_DEPTH"]
DEFAULTSECT = "DEFAULT"
MAX_INTERPOLATION_DEPTH = 10
# exception classes
class Error(Exception):
"""Base class for ConfigParser exceptions."""
def _get_message(self):
"""Getter for 'message'; needed only to override deprecation in
BaseException."""
return self.__message
def _set_message(self, value):
"""Setter for 'message'; needed only to override deprecation in
BaseException."""
self.__message = value
# BaseException.message has been deprecated since Python 2.6. To prevent
# DeprecationWarning from popping up over this pre-existing attribute, use
# a new property that takes lookup precedence.
message = property(_get_message, _set_message)
def __init__(self, msg=''):
self.message = msg
Exception.__init__(self, msg)
def __repr__(self):
return self.message
__str__ = __repr__
class NoSectionError(Error):
"""Raised when no section matches a requested option."""
def __init__(self, section):
Error.__init__(self, 'No section: %r' % (section,))
self.section = section
class DuplicateSectionError(Error):
"""Raised when a section is multiply-created."""
def __init__(self, section):
Error.__init__(self, "Section %r already exists" % section)
self.section = section
class NoOptionError(Error):
"""A requested option was not found."""
def __init__(self, option, section):
Error.__init__(self, "No option %r in section: %r" %
(option, section))
self.option = option
self.section = section
class InterpolationError(Error):
"""Base class for interpolation-related exceptions."""
def __init__(self, option, section, msg):
Error.__init__(self, msg)
self.option = option
self.section = section
class InterpolationMissingOptionError(InterpolationError):
"""A string substitution required a setting which was not available."""
def __init__(self, option, section, rawval, reference):
msg = ("Bad value substitution:\n"
"\tsection: [%s]\n"
"\toption : %s\n"
"\tkey : %s\n"
"\trawval : %s\n"
% (section, option, reference, rawval))
InterpolationError.__init__(self, option, section, msg)
self.reference = reference
class InterpolationSyntaxError(InterpolationError):
"""Raised when the source text into which substitutions are made
does not conform to the required syntax."""
class InterpolationDepthError(InterpolationError):
"""Raised when substitutions are nested too deeply."""
def __init__(self, option, section, rawval):
msg = ("Value interpolation too deeply recursive:\n"
"\tsection: [%s]\n"
"\toption : %s\n"
"\trawval : %s\n"
% (section, option, rawval))
InterpolationError.__init__(self, option, section, msg)
class ParsingError(Error):
"""Raised when a configuration file does not follow legal syntax."""
def __init__(self, filename):
Error.__init__(self, 'File contains parsing errors: %s' % filename)
self.filename = filename
self.errors = []
def append(self, lineno, line):
self.errors.append((lineno, line))
self.message += '\n\t[line %2d]: %s' % (lineno, line)
class MissingSectionHeaderError(ParsingError):
"""Raised when a key-value pair is found before any section header."""
def __init__(self, filename, lineno, line):
Error.__init__(
self,
'File contains no section headers.\nfile: %s, line: %d\n%r' %
(filename, lineno, line))
self.filename = filename
self.lineno = lineno
self.line = line
class RawConfigParser:
def __init__(self, defaults=None, dict_type=_default_dict,
allow_no_value=False):
self._dict = dict_type
self._sections = self._dict()
self._defaults = self._dict()
if allow_no_value:
self._optcre = self.OPTCRE_NV
else:
self._optcre = self.OPTCRE
if defaults:
for key, value in defaults.items():
self._defaults[self.optionxform(key)] = value
def defaults(self):
return self._defaults
def sections(self):
"""Return a list of section names, excluding [DEFAULT]"""
# self._sections will never have [DEFAULT] in it
return self._sections.keys()
def add_section(self, section):
"""Create a new section in the configuration.
Raise DuplicateSectionError if a section by the specified name
already exists. Raise ValueError if name is DEFAULT or any of it's
case-insensitive variants.
"""
if section.lower() == "default":
raise ValueError, 'Invalid section name: %s' % section
if section in self._sections:
raise DuplicateSectionError(section)
self._sections[section] = self._dict()
def has_section(self, section):
"""Indicate whether the named section is present in the configuration.
The DEFAULT section is not acknowledged.
"""
return section in self._sections
def options(self, section):
"""Return a list of option names for the given section name."""
try:
opts = self._sections[section].copy()
except KeyError:
raise NoSectionError(section)
opts.update(self._defaults)
if '__name__' in opts:
del opts['__name__']
return opts.keys()
def read(self, filenames):
"""Read and parse a filename or a list of filenames.
Files that cannot be opened are silently ignored; this is
designed so that you can specify a list of potential
configuration file locations (e.g. current directory, user's
home directory, systemwide directory), and all existing
configuration files in the list will be read. A single
filename may also be given.
Return list of successfully read files.
"""
if isinstance(filenames, basestring):
filenames = [filenames]
read_ok = []
for filename in filenames:
try:
fp = open(filename)
except IOError:
continue
self._read(fp, filename)
fp.close()
read_ok.append(filename)
return read_ok
def readfp(self, fp, filename=None):
"""Like read() but the argument must be a file-like object.
The `fp' argument must have a `readline' method. Optional
second argument is the `filename', which if not given, is
taken from fp.name. If fp has no `name' attribute, `<???>' is
used.
"""
if filename is None:
try:
filename = fp.name
except AttributeError:
filename = '<???>'
self._read(fp, filename)
def get(self, section, option):
opt = self.optionxform(option)
if section not in self._sections:
if section != DEFAULTSECT:
raise NoSectionError(section)
if opt in self._defaults:
return self._defaults[opt]
else:
raise NoOptionError(option, section)
elif opt in self._sections[section]:
return self._sections[section][opt]
elif opt in self._defaults:
return self._defaults[opt]
else:
raise NoOptionError(option, section)
def items(self, section):
try:
d2 = self._sections[section]
except KeyError:
if section != DEFAULTSECT:
raise NoSectionError(section)
d2 = self._dict()
d = self._defaults.copy()
d.update(d2)
if "__name__" in d:
del d["__name__"]
return d.items()
def _get(self, section, conv, option):
return conv(self.get(section, option))
def getint(self, section, option):
return self._get(section, int, option)
def getfloat(self, section, option):
return self._get(section, float, option)
_boolean_states = {'1': True, 'yes': True, 'true': True, 'on': True,
'0': False, 'no': False, 'false': False, 'off': False}
def getboolean(self, section, option):
v = self.get(section, option)
if v.lower() not in self._boolean_states:
raise ValueError, 'Not a boolean: %s' % v
return self._boolean_states[v.lower()]
def optionxform(self, optionstr):
return optionstr.lower()
def has_option(self, section, option):
"""Check for the existence of a given option in a given section."""
if not section or section == DEFAULTSECT:
option = self.optionxform(option)
return option in self._defaults
elif section not in self._sections:
return False
else:
option = self.optionxform(option)
return (option in self._sections[section]
or option in self._defaults)
def set(self, section, option, value=None):
"""Set an option."""
if not section or section == DEFAULTSECT:
sectdict = self._defaults
else:
try:
sectdict = self._sections[section]
except KeyError:
raise NoSectionError(section)
sectdict[self.optionxform(option)] = value
def write(self, fp):
"""Write an .ini-format representation of the configuration state."""
if self._defaults:
fp.write("[%s]\n" % DEFAULTSECT)
for (key, value) in self._defaults.items():
fp.write("%s = %s\n" % (key, str(value).replace('\n', '\n\t')))
fp.write("\n")
for section in self._sections:
fp.write("[%s]\n" % section)
for (key, value) in self._sections[section].items():
if key == "__name__":
continue
if (value is not None) or (self._optcre == self.OPTCRE):
key = " = ".join((key, str(value).replace('\n', '\n\t')))
fp.write("%s\n" % (key))
fp.write("\n")
def remove_option(self, section, option):
"""Remove an option."""
if not section or section == DEFAULTSECT:
sectdict = self._defaults
else:
try:
sectdict = self._sections[section]
except KeyError:
raise NoSectionError(section)
option = self.optionxform(option)
existed = option in sectdict
if existed:
del sectdict[option]
return existed
def remove_section(self, section):
"""Remove a file section."""
existed = section in self._sections
if existed:
del self._sections[section]
return existed
#
# Regular expressions for parsing section headers and options.
#
SECTCRE = re.compile(
r'\[' # [
r'(?P<header>[^]]+)' # very permissive!
r'\]' # ]
)
OPTCRE = re.compile(
r'(?P<option>[^:=\s][^:=]*)' # very permissive!
r'\s*(?P<vi>[:=])\s*' # any number of space/tab,
# followed by separator
# (either : or =), followed
# by any # space/tab
r'(?P<value>.*)$' # everything up to eol
)
OPTCRE_NV = re.compile(
r'(?P<option>[^:=\s][^:=]*)' # very permissive!
r'\s*(?:' # any number of space/tab,
r'(?P<vi>[:=])\s*' # optionally followed by
# separator (either : or
# =), followed by any #
# space/tab
r'(?P<value>.*))?$' # everything up to eol
)
def _read(self, fp, fpname):
"""Parse a sectioned setup file.
The sections in setup file contains a title line at the top,
indicated by a name in square brackets (`[]'), plus key/value
options lines, indicated by `name: value' format lines.
Continuations are represented by an embedded newline then
leading whitespace. Blank lines, lines beginning with a '#',
and just about everything else are ignored.
"""
cursect = None # None, or a dictionary
optname = None
lineno = 0
e = None # None, or an exception
while True:
line = fp.readline()
if not line:
break
lineno = lineno + 1
# comment or blank line?
if line.strip() == '' or line[0] in '#;':
continue
if line.split(None, 1)[0].lower() == 'rem' and line[0] in "rR":
# no leading whitespace
continue
# continuation line?
if line[0].isspace() and cursect is not None and optname:
value = line.strip()
if value:
cursect[optname].append(value)
# a section header or option header?
else:
# is it a section header?
mo = self.SECTCRE.match(line)
if mo:
sectname = mo.group('header')
if sectname in self._sections:
cursect = self._sections[sectname]
elif sectname == DEFAULTSECT:
cursect = self._defaults
else:
cursect = self._dict()
cursect['__name__'] = sectname
self._sections[sectname] = cursect
# So sections can't start with a continuation line
optname = None
# no section header in the file?
elif cursect is None:
raise MissingSectionHeaderError(fpname, lineno, line)
# an option line?
else:
mo = self._optcre.match(line)
if mo:
optname, vi, optval = mo.group('option', 'vi', 'value')
optname = self.optionxform(optname.rstrip())
# This check is fine because the OPTCRE cannot
# match if it would set optval to None
if optval is not None:
if vi in ('=', ':') and ';' in optval:
# ';' is a comment delimiter only if it follows
# a spacing character
pos = optval.find(';')
if pos != -1 and optval[pos-1].isspace():
optval = optval[:pos]
optval = optval.strip()
# allow empty values
if optval == '""':
optval = ''
cursect[optname] = [optval]
else:
# valueless option handling
cursect[optname] = optval
else:
# a non-fatal parsing error occurred. set up the
# exception but keep going. the exception will be
# raised at the end of the file and will contain a
# list of all bogus lines
if not e:
e = ParsingError(fpname)
e.append(lineno, repr(line))
# if any parsing errors occurred, raise an exception
if e:
raise e
# join the multi-line values collected while reading
all_sections = [self._defaults]
all_sections.extend(self._sections.values())
for options in all_sections:
for name, val in options.items():
if isinstance(val, list):
options[name] = '\n'.join(val)
class ConfigParser(RawConfigParser):
def get(self, section, option, raw=False, vars=None):
"""Get an option value for a given section.
If `vars' is provided, it must be a dictionary. The option is looked up
in `vars' (if provided), `section', and in `defaults' in that order.
All % interpolations are expanded in the return values, unless the
optional argument `raw' is true. Values for interpolation keys are
looked up in the same manner as the option.
The section DEFAULT is special.
"""
d = self._defaults.copy()
try:
d.update(self._sections[section])
except KeyError:
if section != DEFAULTSECT:
raise NoSectionError(section)
# Update with the entry specific variables
if vars:
for key, value in vars.items():
d[self.optionxform(key)] = value
option = self.optionxform(option)
try:
value = d[option]
except KeyError:
raise NoOptionError(option, section)
if raw or value is None:
return value
else:
return self._interpolate(section, option, value, d)
def items(self, section, raw=False, vars=None):
"""Return a list of tuples with (name, value) for each option
in the section.
All % interpolations are expanded in the return values, based on the
defaults passed into the constructor, unless the optional argument
`raw' is true. Additional substitutions may be provided using the
`vars' argument, which must be a dictionary whose contents overrides
any pre-existing defaults.
The section DEFAULT is special.
"""
d = self._defaults.copy()
try:
d.update(self._sections[section])
except KeyError:
if section != DEFAULTSECT:
raise NoSectionError(section)
# Update with the entry specific variables
if vars:
for key, value in vars.items():
d[self.optionxform(key)] = value
options = d.keys()
if "__name__" in options:
options.remove("__name__")
if raw:
return [(option, d[option])
for option in options]
else:
return [(option, self._interpolate(section, option, d[option], d))
for option in options]
def _interpolate(self, section, option, rawval, vars):
# do the string interpolation
value = rawval
depth = MAX_INTERPOLATION_DEPTH
while depth: # Loop through this until it's done
depth -= 1
if value and "%(" in value:
value = self._KEYCRE.sub(self._interpolation_replace, value)
try:
value = value % vars
except KeyError, e:
raise InterpolationMissingOptionError(
option, section, rawval, e.args[0])
else:
break
if value and "%(" in value:
raise InterpolationDepthError(option, section, rawval)
return value
_KEYCRE = re.compile(r"%\(([^)]*)\)s|.")
def _interpolation_replace(self, match):
s = match.group(1)
if s is None:
return match.group()
else:
return "%%(%s)s" % self.optionxform(s)
class SafeConfigParser(ConfigParser):
def _interpolate(self, section, option, rawval, vars):
# do the string interpolation
L = []
self._interpolate_some(option, L, rawval, section, vars, 1)
return ''.join(L)
_interpvar_re = re.compile(r"%\(([^)]+)\)s")
def _interpolate_some(self, option, accum, rest, section, map, depth):
if depth > MAX_INTERPOLATION_DEPTH:
raise InterpolationDepthError(option, section, rest)
while rest:
p = rest.find("%")
if p < 0:
accum.append(rest)
return
if p > 0:
accum.append(rest[:p])
rest = rest[p:]
# p is no longer used
c = rest[1:2]
if c == "%":
accum.append("%")
rest = rest[2:]
elif c == "(":
m = self._interpvar_re.match(rest)
if m is None:
raise InterpolationSyntaxError(option, section,
"bad interpolation variable reference %r" % rest)
var = self.optionxform(m.group(1))
rest = rest[m.end():]
try:
v = map[var]
except KeyError:
raise InterpolationMissingOptionError(
option, section, rest, var)
if "%" in v:
self._interpolate_some(option, accum, v,
section, map, depth + 1)
else:
accum.append(v)
else:
raise InterpolationSyntaxError(
option, section,
"'%%' must be followed by '%%' or '(', found: %r" % (rest,))
def set(self, section, option, value=None):
"""Set an option. Extend ConfigParser.set: check for string values."""
# The only legal non-string value if we allow valueless
# options is None, so we need to check if the value is a
# string if:
# - we do not allow valueless options, or
# - we allow valueless options but the value is not None
if self._optcre is self.OPTCRE or value:
if not isinstance(value, basestring):
raise TypeError("option values must be strings")
if value is not None:
# check for bad percent signs:
# first, replace all "good" interpolations
tmp_value = value.replace('%%', '')
tmp_value = self._interpvar_re.sub('', tmp_value)
# then, check if there's a lone percent sign left
if '%' in tmp_value:
raise ValueError("invalid interpolation syntax in %r at "
"position %d" % (value, tmp_value.find('%')))
ConfigParser.set(self, section, option, value)

756
src/python/stdlib/Cookie.py Normal file
View File

@@ -0,0 +1,756 @@
#!/usr/bin/env python
#
####
# Copyright 2000 by Timothy O'Malley <timo@alum.mit.edu>
#
# All Rights Reserved
#
# Permission to use, copy, modify, and distribute this software
# and its documentation for any purpose and without fee is hereby
# granted, provided that the above copyright notice appear in all
# copies and that both that copyright notice and this permission
# notice appear in supporting documentation, and that the name of
# Timothy O'Malley not be used in advertising or publicity
# pertaining to distribution of the software without specific, written
# prior permission.
#
# Timothy O'Malley DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS
# SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
# AND FITNESS, IN NO EVENT SHALL Timothy O'Malley BE LIABLE FOR
# ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
# WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
# ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
# PERFORMANCE OF THIS SOFTWARE.
#
####
#
# Id: Cookie.py,v 2.29 2000/08/23 05:28:49 timo Exp
# by Timothy O'Malley <timo@alum.mit.edu>
#
# Cookie.py is a Python module for the handling of HTTP
# cookies as a Python dictionary. See RFC 2109 for more
# information on cookies.
#
# The original idea to treat Cookies as a dictionary came from
# Dave Mitchell (davem@magnet.com) in 1995, when he released the
# first version of nscookie.py.
#
####
r"""
Here's a sample session to show how to use this module.
At the moment, this is the only documentation.
The Basics
----------
Importing is easy..
>>> import Cookie
Most of the time you start by creating a cookie. Cookies come in
three flavors, each with slightly different encoding semantics, but
more on that later.
>>> C = Cookie.SimpleCookie()
>>> C = Cookie.SerialCookie()
>>> C = Cookie.SmartCookie()
[Note: Long-time users of Cookie.py will remember using
Cookie.Cookie() to create an Cookie object. Although deprecated, it
is still supported by the code. See the Backward Compatibility notes
for more information.]
Once you've created your Cookie, you can add values just as if it were
a dictionary.
>>> C = Cookie.SmartCookie()
>>> C["fig"] = "newton"
>>> C["sugar"] = "wafer"
>>> C.output()
'Set-Cookie: fig=newton\r\nSet-Cookie: sugar=wafer'
Notice that the printable representation of a Cookie is the
appropriate format for a Set-Cookie: header. This is the
default behavior. You can change the header and printed
attributes by using the .output() function
>>> C = Cookie.SmartCookie()
>>> C["rocky"] = "road"
>>> C["rocky"]["path"] = "/cookie"
>>> print C.output(header="Cookie:")
Cookie: rocky=road; Path=/cookie
>>> print C.output(attrs=[], header="Cookie:")
Cookie: rocky=road
The load() method of a Cookie extracts cookies from a string. In a
CGI script, you would use this method to extract the cookies from the
HTTP_COOKIE environment variable.
>>> C = Cookie.SmartCookie()
>>> C.load("chips=ahoy; vienna=finger")
>>> C.output()
'Set-Cookie: chips=ahoy\r\nSet-Cookie: vienna=finger'
The load() method is darn-tootin smart about identifying cookies
within a string. Escaped quotation marks, nested semicolons, and other
such trickeries do not confuse it.
>>> C = Cookie.SmartCookie()
>>> C.load('keebler="E=everybody; L=\\"Loves\\"; fudge=\\012;";')
>>> print C
Set-Cookie: keebler="E=everybody; L=\"Loves\"; fudge=\012;"
Each element of the Cookie also supports all of the RFC 2109
Cookie attributes. Here's an example which sets the Path
attribute.
>>> C = Cookie.SmartCookie()
>>> C["oreo"] = "doublestuff"
>>> C["oreo"]["path"] = "/"
>>> print C
Set-Cookie: oreo=doublestuff; Path=/
Each dictionary element has a 'value' attribute, which gives you
back the value associated with the key.
>>> C = Cookie.SmartCookie()
>>> C["twix"] = "none for you"
>>> C["twix"].value
'none for you'
A Bit More Advanced
-------------------
As mentioned before, there are three different flavors of Cookie
objects, each with different encoding/decoding semantics. This
section briefly discusses the differences.
SimpleCookie
The SimpleCookie expects that all values should be standard strings.
Just to be sure, SimpleCookie invokes the str() builtin to convert
the value to a string, when the values are set dictionary-style.
>>> C = Cookie.SimpleCookie()
>>> C["number"] = 7
>>> C["string"] = "seven"
>>> C["number"].value
'7'
>>> C["string"].value
'seven'
>>> C.output()
'Set-Cookie: number=7\r\nSet-Cookie: string=seven'
SerialCookie
The SerialCookie expects that all values should be serialized using
cPickle (or pickle, if cPickle isn't available). As a result of
serializing, SerialCookie can save almost any Python object to a
value, and recover the exact same object when the cookie has been
returned. (SerialCookie can yield some strange-looking cookie
values, however.)
>>> C = Cookie.SerialCookie()
>>> C["number"] = 7
>>> C["string"] = "seven"
>>> C["number"].value
7
>>> C["string"].value
'seven'
>>> C.output()
'Set-Cookie: number="I7\\012."\r\nSet-Cookie: string="S\'seven\'\\012p1\\012."'
Be warned, however, if SerialCookie cannot de-serialize a value (because
it isn't a valid pickle'd object), IT WILL RAISE AN EXCEPTION.
SmartCookie
The SmartCookie combines aspects of each of the other two flavors.
When setting a value in a dictionary-fashion, the SmartCookie will
serialize (ala cPickle) the value *if and only if* it isn't a
Python string. String objects are *not* serialized. Similarly,
when the load() method parses out values, it attempts to de-serialize
the value. If it fails, then it fallsback to treating the value
as a string.
>>> C = Cookie.SmartCookie()
>>> C["number"] = 7
>>> C["string"] = "seven"
>>> C["number"].value
7
>>> C["string"].value
'seven'
>>> C.output()
'Set-Cookie: number="I7\\012."\r\nSet-Cookie: string=seven'
Backwards Compatibility
-----------------------
In order to keep compatibilty with earlier versions of Cookie.py,
it is still possible to use Cookie.Cookie() to create a Cookie. In
fact, this simply returns a SmartCookie.
>>> C = Cookie.Cookie()
>>> print C.__class__.__name__
SmartCookie
Finis.
""" #"
# ^
# |----helps out font-lock
#
# Import our required modules
#
import string
try:
from cPickle import dumps, loads
except ImportError:
from pickle import dumps, loads
import re, warnings
__all__ = ["CookieError","BaseCookie","SimpleCookie","SerialCookie",
"SmartCookie","Cookie"]
_nulljoin = ''.join
_semispacejoin = '; '.join
_spacejoin = ' '.join
#
# Define an exception visible to External modules
#
class CookieError(Exception):
pass
# These quoting routines conform to the RFC2109 specification, which in
# turn references the character definitions from RFC2068. They provide
# a two-way quoting algorithm. Any non-text character is translated
# into a 4 character sequence: a forward-slash followed by the
# three-digit octal equivalent of the character. Any '\' or '"' is
# quoted with a preceeding '\' slash.
#
# These are taken from RFC2068 and RFC2109.
# _LegalChars is the list of chars which don't require "'s
# _Translator hash-table for fast quoting
#
_LegalChars = string.ascii_letters + string.digits + "!#$%&'*+-.^_`|~"
_Translator = {
'\000' : '\\000', '\001' : '\\001', '\002' : '\\002',
'\003' : '\\003', '\004' : '\\004', '\005' : '\\005',
'\006' : '\\006', '\007' : '\\007', '\010' : '\\010',
'\011' : '\\011', '\012' : '\\012', '\013' : '\\013',
'\014' : '\\014', '\015' : '\\015', '\016' : '\\016',
'\017' : '\\017', '\020' : '\\020', '\021' : '\\021',
'\022' : '\\022', '\023' : '\\023', '\024' : '\\024',
'\025' : '\\025', '\026' : '\\026', '\027' : '\\027',
'\030' : '\\030', '\031' : '\\031', '\032' : '\\032',
'\033' : '\\033', '\034' : '\\034', '\035' : '\\035',
'\036' : '\\036', '\037' : '\\037',
'"' : '\\"', '\\' : '\\\\',
'\177' : '\\177', '\200' : '\\200', '\201' : '\\201',
'\202' : '\\202', '\203' : '\\203', '\204' : '\\204',
'\205' : '\\205', '\206' : '\\206', '\207' : '\\207',
'\210' : '\\210', '\211' : '\\211', '\212' : '\\212',
'\213' : '\\213', '\214' : '\\214', '\215' : '\\215',
'\216' : '\\216', '\217' : '\\217', '\220' : '\\220',
'\221' : '\\221', '\222' : '\\222', '\223' : '\\223',
'\224' : '\\224', '\225' : '\\225', '\226' : '\\226',
'\227' : '\\227', '\230' : '\\230', '\231' : '\\231',
'\232' : '\\232', '\233' : '\\233', '\234' : '\\234',
'\235' : '\\235', '\236' : '\\236', '\237' : '\\237',
'\240' : '\\240', '\241' : '\\241', '\242' : '\\242',
'\243' : '\\243', '\244' : '\\244', '\245' : '\\245',
'\246' : '\\246', '\247' : '\\247', '\250' : '\\250',
'\251' : '\\251', '\252' : '\\252', '\253' : '\\253',
'\254' : '\\254', '\255' : '\\255', '\256' : '\\256',
'\257' : '\\257', '\260' : '\\260', '\261' : '\\261',
'\262' : '\\262', '\263' : '\\263', '\264' : '\\264',
'\265' : '\\265', '\266' : '\\266', '\267' : '\\267',
'\270' : '\\270', '\271' : '\\271', '\272' : '\\272',
'\273' : '\\273', '\274' : '\\274', '\275' : '\\275',
'\276' : '\\276', '\277' : '\\277', '\300' : '\\300',
'\301' : '\\301', '\302' : '\\302', '\303' : '\\303',
'\304' : '\\304', '\305' : '\\305', '\306' : '\\306',
'\307' : '\\307', '\310' : '\\310', '\311' : '\\311',
'\312' : '\\312', '\313' : '\\313', '\314' : '\\314',
'\315' : '\\315', '\316' : '\\316', '\317' : '\\317',
'\320' : '\\320', '\321' : '\\321', '\322' : '\\322',
'\323' : '\\323', '\324' : '\\324', '\325' : '\\325',
'\326' : '\\326', '\327' : '\\327', '\330' : '\\330',
'\331' : '\\331', '\332' : '\\332', '\333' : '\\333',
'\334' : '\\334', '\335' : '\\335', '\336' : '\\336',
'\337' : '\\337', '\340' : '\\340', '\341' : '\\341',
'\342' : '\\342', '\343' : '\\343', '\344' : '\\344',
'\345' : '\\345', '\346' : '\\346', '\347' : '\\347',
'\350' : '\\350', '\351' : '\\351', '\352' : '\\352',
'\353' : '\\353', '\354' : '\\354', '\355' : '\\355',
'\356' : '\\356', '\357' : '\\357', '\360' : '\\360',
'\361' : '\\361', '\362' : '\\362', '\363' : '\\363',
'\364' : '\\364', '\365' : '\\365', '\366' : '\\366',
'\367' : '\\367', '\370' : '\\370', '\371' : '\\371',
'\372' : '\\372', '\373' : '\\373', '\374' : '\\374',
'\375' : '\\375', '\376' : '\\376', '\377' : '\\377'
}
_idmap = ''.join(chr(x) for x in xrange(256))
def _quote(str, LegalChars=_LegalChars,
idmap=_idmap, translate=string.translate):
#
# If the string does not need to be double-quoted,
# then just return the string. Otherwise, surround
# the string in doublequotes and precede quote (with a \)
# special characters.
#
if "" == translate(str, idmap, LegalChars):
return str
else:
return '"' + _nulljoin( map(_Translator.get, str, str) ) + '"'
# end _quote
_OctalPatt = re.compile(r"\\[0-3][0-7][0-7]")
_QuotePatt = re.compile(r"[\\].")
def _unquote(str):
# If there aren't any doublequotes,
# then there can't be any special characters. See RFC 2109.
if len(str) < 2:
return str
if str[0] != '"' or str[-1] != '"':
return str
# We have to assume that we must decode this string.
# Down to work.
# Remove the "s
str = str[1:-1]
# Check for special sequences. Examples:
# \012 --> \n
# \" --> "
#
i = 0
n = len(str)
res = []
while 0 <= i < n:
Omatch = _OctalPatt.search(str, i)
Qmatch = _QuotePatt.search(str, i)
if not Omatch and not Qmatch: # Neither matched
res.append(str[i:])
break
# else:
j = k = -1
if Omatch: j = Omatch.start(0)
if Qmatch: k = Qmatch.start(0)
if Qmatch and ( not Omatch or k < j ): # QuotePatt matched
res.append(str[i:k])
res.append(str[k+1])
i = k+2
else: # OctalPatt matched
res.append(str[i:j])
res.append( chr( int(str[j+1:j+4], 8) ) )
i = j+4
return _nulljoin(res)
# end _unquote
# The _getdate() routine is used to set the expiration time in
# the cookie's HTTP header. By default, _getdate() returns the
# current time in the appropriate "expires" format for a
# Set-Cookie header. The one optional argument is an offset from
# now, in seconds. For example, an offset of -3600 means "one hour ago".
# The offset may be a floating point number.
#
_weekdayname = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
_monthname = [None,
'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
def _getdate(future=0, weekdayname=_weekdayname, monthname=_monthname):
from time import gmtime, time
now = time()
year, month, day, hh, mm, ss, wd, y, z = gmtime(now + future)
return "%s, %02d-%3s-%4d %02d:%02d:%02d GMT" % \
(weekdayname[wd], day, monthname[month], year, hh, mm, ss)
#
# A class to hold ONE key,value pair.
# In a cookie, each such pair may have several attributes.
# so this class is used to keep the attributes associated
# with the appropriate key,value pair.
# This class also includes a coded_value attribute, which
# is used to hold the network representation of the
# value. This is most useful when Python objects are
# pickled for network transit.
#
class Morsel(dict):
# RFC 2109 lists these attributes as reserved:
# path comment domain
# max-age secure version
#
# For historical reasons, these attributes are also reserved:
# expires
#
# This is an extension from Microsoft:
# httponly
#
# This dictionary provides a mapping from the lowercase
# variant on the left to the appropriate traditional
# formatting on the right.
_reserved = { "expires" : "expires",
"path" : "Path",
"comment" : "Comment",
"domain" : "Domain",
"max-age" : "Max-Age",
"secure" : "secure",
"httponly" : "httponly",
"version" : "Version",
}
def __init__(self):
# Set defaults
self.key = self.value = self.coded_value = None
# Set default attributes
for K in self._reserved:
dict.__setitem__(self, K, "")
# end __init__
def __setitem__(self, K, V):
K = K.lower()
if not K in self._reserved:
raise CookieError("Invalid Attribute %s" % K)
dict.__setitem__(self, K, V)
# end __setitem__
def isReservedKey(self, K):
return K.lower() in self._reserved
# end isReservedKey
def set(self, key, val, coded_val,
LegalChars=_LegalChars,
idmap=_idmap, translate=string.translate):
# First we verify that the key isn't a reserved word
# Second we make sure it only contains legal characters
if key.lower() in self._reserved:
raise CookieError("Attempt to set a reserved key: %s" % key)
if "" != translate(key, idmap, LegalChars):
raise CookieError("Illegal key value: %s" % key)
# It's a good key, so save it.
self.key = key
self.value = val
self.coded_value = coded_val
# end set
def output(self, attrs=None, header = "Set-Cookie:"):
return "%s %s" % ( header, self.OutputString(attrs) )
__str__ = output
def __repr__(self):
return '<%s: %s=%s>' % (self.__class__.__name__,
self.key, repr(self.value) )
def js_output(self, attrs=None):
# Print javascript
return """
<script type="text/javascript">
<!-- begin hiding
document.cookie = \"%s\";
// end hiding -->
</script>
""" % ( self.OutputString(attrs).replace('"',r'\"'), )
# end js_output()
def OutputString(self, attrs=None):
# Build up our result
#
result = []
RA = result.append
# First, the key=value pair
RA("%s=%s" % (self.key, self.coded_value))
# Now add any defined attributes
if attrs is None:
attrs = self._reserved
items = self.items()
items.sort()
for K,V in items:
if V == "": continue
if K not in attrs: continue
if K == "expires" and type(V) == type(1):
RA("%s=%s" % (self._reserved[K], _getdate(V)))
elif K == "max-age" and type(V) == type(1):
RA("%s=%d" % (self._reserved[K], V))
elif K == "secure":
RA(str(self._reserved[K]))
elif K == "httponly":
RA(str(self._reserved[K]))
else:
RA("%s=%s" % (self._reserved[K], V))
# Return the result
return _semispacejoin(result)
# end OutputString
# end Morsel class
#
# Pattern for finding cookie
#
# This used to be strict parsing based on the RFC2109 and RFC2068
# specifications. I have since discovered that MSIE 3.0x doesn't
# follow the character rules outlined in those specs. As a
# result, the parsing rules here are less strict.
#
_LegalCharsPatt = r"[\w\d!#%&'~_`><@,:/\$\*\+\-\.\^\|\)\(\?\}\{\=]"
_CookiePattern = re.compile(
r"(?x)" # This is a Verbose pattern
r"(?P<key>" # Start of group 'key'
""+ _LegalCharsPatt +"+?" # Any word of at least one letter, nongreedy
r")" # End of group 'key'
r"\s*=\s*" # Equal Sign
r"(?P<val>" # Start of group 'val'
r'"(?:[^\\"]|\\.)*"' # Any doublequoted string
r"|" # or
r"\w{3},\s[\w\d-]{9,11}\s[\d:]{8}\sGMT" # Special case for "expires" attr
r"|" # or
""+ _LegalCharsPatt +"*" # Any word or empty string
r")" # End of group 'val'
r"\s*;?" # Probably ending in a semi-colon
)
# At long last, here is the cookie class.
# Using this class is almost just like using a dictionary.
# See this module's docstring for example usage.
#
class BaseCookie(dict):
# A container class for a set of Morsels
#
def value_decode(self, val):
"""real_value, coded_value = value_decode(STRING)
Called prior to setting a cookie's value from the network
representation. The VALUE is the value read from HTTP
header.
Override this function to modify the behavior of cookies.
"""
return val, val
# end value_encode
def value_encode(self, val):
"""real_value, coded_value = value_encode(VALUE)
Called prior to setting a cookie's value from the dictionary
representation. The VALUE is the value being assigned.
Override this function to modify the behavior of cookies.
"""
strval = str(val)
return strval, strval
# end value_encode
def __init__(self, input=None):
if input: self.load(input)
# end __init__
def __set(self, key, real_value, coded_value):
"""Private method for setting a cookie's value"""
M = self.get(key, Morsel())
M.set(key, real_value, coded_value)
dict.__setitem__(self, key, M)
# end __set
def __setitem__(self, key, value):
"""Dictionary style assignment."""
rval, cval = self.value_encode(value)
self.__set(key, rval, cval)
# end __setitem__
def output(self, attrs=None, header="Set-Cookie:", sep="\015\012"):
"""Return a string suitable for HTTP."""
result = []
items = self.items()
items.sort()
for K,V in items:
result.append( V.output(attrs, header) )
return sep.join(result)
# end output
__str__ = output
def __repr__(self):
L = []
items = self.items()
items.sort()
for K,V in items:
L.append( '%s=%s' % (K,repr(V.value) ) )
return '<%s: %s>' % (self.__class__.__name__, _spacejoin(L))
def js_output(self, attrs=None):
"""Return a string suitable for JavaScript."""
result = []
items = self.items()
items.sort()
for K,V in items:
result.append( V.js_output(attrs) )
return _nulljoin(result)
# end js_output
def load(self, rawdata):
"""Load cookies from a string (presumably HTTP_COOKIE) or
from a dictionary. Loading cookies from a dictionary 'd'
is equivalent to calling:
map(Cookie.__setitem__, d.keys(), d.values())
"""
if type(rawdata) == type(""):
self.__ParseString(rawdata)
else:
# self.update() wouldn't call our custom __setitem__
for k, v in rawdata.items():
self[k] = v
return
# end load()
def __ParseString(self, str, patt=_CookiePattern):
i = 0 # Our starting point
n = len(str) # Length of string
M = None # current morsel
while 0 <= i < n:
# Start looking for a cookie
match = patt.search(str, i)
if not match: break # No more cookies
K,V = match.group("key"), match.group("val")
i = match.end(0)
# Parse the key, value in case it's metainfo
if K[0] == "$":
# We ignore attributes which pertain to the cookie
# mechanism as a whole. See RFC 2109.
# (Does anyone care?)
if M:
M[ K[1:] ] = V
elif K.lower() in Morsel._reserved:
if M:
M[ K ] = _unquote(V)
else:
rval, cval = self.value_decode(V)
self.__set(K, rval, cval)
M = self[K]
# end __ParseString
# end BaseCookie class
class SimpleCookie(BaseCookie):
"""SimpleCookie
SimpleCookie supports strings as cookie values. When setting
the value using the dictionary assignment notation, SimpleCookie
calls the builtin str() to convert the value to a string. Values
received from HTTP are kept as strings.
"""
def value_decode(self, val):
return _unquote( val ), val
def value_encode(self, val):
strval = str(val)
return strval, _quote( strval )
# end SimpleCookie
class SerialCookie(BaseCookie):
"""SerialCookie
SerialCookie supports arbitrary objects as cookie values. All
values are serialized (using cPickle) before being sent to the
client. All incoming values are assumed to be valid Pickle
representations. IF AN INCOMING VALUE IS NOT IN A VALID PICKLE
FORMAT, THEN AN EXCEPTION WILL BE RAISED.
Note: Large cookie values add overhead because they must be
retransmitted on every HTTP transaction.
Note: HTTP has a 2k limit on the size of a cookie. This class
does not check for this limit, so be careful!!!
"""
def __init__(self, input=None):
warnings.warn("SerialCookie class is insecure; do not use it",
DeprecationWarning)
BaseCookie.__init__(self, input)
# end __init__
def value_decode(self, val):
# This could raise an exception!
return loads( _unquote(val) ), val
def value_encode(self, val):
return val, _quote( dumps(val) )
# end SerialCookie
class SmartCookie(BaseCookie):
"""SmartCookie
SmartCookie supports arbitrary objects as cookie values. If the
object is a string, then it is quoted. If the object is not a
string, however, then SmartCookie will use cPickle to serialize
the object into a string representation.
Note: Large cookie values add overhead because they must be
retransmitted on every HTTP transaction.
Note: HTTP has a 2k limit on the size of a cookie. This class
does not check for this limit, so be careful!!!
"""
def __init__(self, input=None):
warnings.warn("Cookie/SmartCookie class is insecure; do not use it",
DeprecationWarning)
BaseCookie.__init__(self, input)
# end __init__
def value_decode(self, val):
strval = _unquote(val)
try:
return loads(strval), val
except:
return strval, val
def value_encode(self, val):
if type(val) == type(""):
return val, _quote(val)
else:
return val, _quote( dumps(val) )
# end SmartCookie
###########################################################
# Backwards Compatibility: Don't break any existing code!
# We provide Cookie() as an alias for SmartCookie()
Cookie = SmartCookie
#
###########################################################
def _test():
import doctest, Cookie
return doctest.testmod(Cookie)
if __name__ == "__main__":
_test()
#Local Variables:
#tab-width: 4
#end:

View File

@@ -0,0 +1,279 @@
"""Self documenting XML-RPC Server.
This module can be used to create XML-RPC servers that
serve pydoc-style documentation in response to HTTP
GET requests. This documentation is dynamically generated
based on the functions and methods registered with the
server.
This module is built upon the pydoc and SimpleXMLRPCServer
modules.
"""
import pydoc
import inspect
import re
import sys
from SimpleXMLRPCServer import (SimpleXMLRPCServer,
SimpleXMLRPCRequestHandler,
CGIXMLRPCRequestHandler,
resolve_dotted_attribute)
class ServerHTMLDoc(pydoc.HTMLDoc):
"""Class used to generate pydoc HTML document for a server"""
def markup(self, text, escape=None, funcs={}, classes={}, methods={}):
"""Mark up some plain text, given a context of symbols to look for.
Each context dictionary maps object names to anchor names."""
escape = escape or self.escape
results = []
here = 0
# XXX Note that this regular expression does not allow for the
# hyperlinking of arbitrary strings being used as method
# names. Only methods with names consisting of word characters
# and '.'s are hyperlinked.
pattern = re.compile(r'\b((http|ftp)://\S+[\w/]|'
r'RFC[- ]?(\d+)|'
r'PEP[- ]?(\d+)|'
r'(self\.)?((?:\w|\.)+))\b')
while 1:
match = pattern.search(text, here)
if not match: break
start, end = match.span()
results.append(escape(text[here:start]))
all, scheme, rfc, pep, selfdot, name = match.groups()
if scheme:
url = escape(all).replace('"', '&quot;')
results.append('<a href="%s">%s</a>' % (url, url))
elif rfc:
url = 'http://www.rfc-editor.org/rfc/rfc%d.txt' % int(rfc)
results.append('<a href="%s">%s</a>' % (url, escape(all)))
elif pep:
url = 'http://www.python.org/dev/peps/pep-%04d/' % int(pep)
results.append('<a href="%s">%s</a>' % (url, escape(all)))
elif text[end:end+1] == '(':
results.append(self.namelink(name, methods, funcs, classes))
elif selfdot:
results.append('self.<strong>%s</strong>' % name)
else:
results.append(self.namelink(name, classes))
here = end
results.append(escape(text[here:]))
return ''.join(results)
def docroutine(self, object, name, mod=None,
funcs={}, classes={}, methods={}, cl=None):
"""Produce HTML documentation for a function or method object."""
anchor = (cl and cl.__name__ or '') + '-' + name
note = ''
title = '<a name="%s"><strong>%s</strong></a>' % (
self.escape(anchor), self.escape(name))
if inspect.ismethod(object):
args, varargs, varkw, defaults = inspect.getargspec(object.im_func)
# exclude the argument bound to the instance, it will be
# confusing to the non-Python user
argspec = inspect.formatargspec (
args[1:],
varargs,
varkw,
defaults,
formatvalue=self.formatvalue
)
elif inspect.isfunction(object):
args, varargs, varkw, defaults = inspect.getargspec(object)
argspec = inspect.formatargspec(
args, varargs, varkw, defaults, formatvalue=self.formatvalue)
else:
argspec = '(...)'
if isinstance(object, tuple):
argspec = object[0] or argspec
docstring = object[1] or ""
else:
docstring = pydoc.getdoc(object)
decl = title + argspec + (note and self.grey(
'<font face="helvetica, arial">%s</font>' % note))
doc = self.markup(
docstring, self.preformat, funcs, classes, methods)
doc = doc and '<dd><tt>%s</tt></dd>' % doc
return '<dl><dt>%s</dt>%s</dl>\n' % (decl, doc)
def docserver(self, server_name, package_documentation, methods):
"""Produce HTML documentation for an XML-RPC server."""
fdict = {}
for key, value in methods.items():
fdict[key] = '#-' + key
fdict[value] = fdict[key]
server_name = self.escape(server_name)
head = '<big><big><strong>%s</strong></big></big>' % server_name
result = self.heading(head, '#ffffff', '#7799ee')
doc = self.markup(package_documentation, self.preformat, fdict)
doc = doc and '<tt>%s</tt>' % doc
result = result + '<p>%s</p>\n' % doc
contents = []
method_items = sorted(methods.items())
for key, value in method_items:
contents.append(self.docroutine(value, key, funcs=fdict))
result = result + self.bigsection(
'Methods', '#ffffff', '#eeaa77', pydoc.join(contents))
return result
class XMLRPCDocGenerator:
"""Generates documentation for an XML-RPC server.
This class is designed as mix-in and should not
be constructed directly.
"""
def __init__(self):
# setup variables used for HTML documentation
self.server_name = 'XML-RPC Server Documentation'
self.server_documentation = \
"This server exports the following methods through the XML-RPC "\
"protocol."
self.server_title = 'XML-RPC Server Documentation'
def set_server_title(self, server_title):
"""Set the HTML title of the generated server documentation"""
self.server_title = server_title
def set_server_name(self, server_name):
"""Set the name of the generated HTML server documentation"""
self.server_name = server_name
def set_server_documentation(self, server_documentation):
"""Set the documentation string for the entire server."""
self.server_documentation = server_documentation
def generate_html_documentation(self):
"""generate_html_documentation() => html documentation for the server
Generates HTML documentation for the server using introspection for
installed functions and instances that do not implement the
_dispatch method. Alternatively, instances can choose to implement
the _get_method_argstring(method_name) method to provide the
argument string used in the documentation and the
_methodHelp(method_name) method to provide the help text used
in the documentation."""
methods = {}
for method_name in self.system_listMethods():
if method_name in self.funcs:
method = self.funcs[method_name]
elif self.instance is not None:
method_info = [None, None] # argspec, documentation
if hasattr(self.instance, '_get_method_argstring'):
method_info[0] = self.instance._get_method_argstring(method_name)
if hasattr(self.instance, '_methodHelp'):
method_info[1] = self.instance._methodHelp(method_name)
method_info = tuple(method_info)
if method_info != (None, None):
method = method_info
elif not hasattr(self.instance, '_dispatch'):
try:
method = resolve_dotted_attribute(
self.instance,
method_name
)
except AttributeError:
method = method_info
else:
method = method_info
else:
assert 0, "Could not find method in self.functions and no "\
"instance installed"
methods[method_name] = method
documenter = ServerHTMLDoc()
documentation = documenter.docserver(
self.server_name,
self.server_documentation,
methods
)
return documenter.page(self.server_title, documentation)
class DocXMLRPCRequestHandler(SimpleXMLRPCRequestHandler):
"""XML-RPC and documentation request handler class.
Handles all HTTP POST requests and attempts to decode them as
XML-RPC requests.
Handles all HTTP GET requests and interprets them as requests
for documentation.
"""
def do_GET(self):
"""Handles the HTTP GET request.
Interpret all HTTP GET requests as requests for server
documentation.
"""
# Check that the path is legal
if not self.is_rpc_path_valid():
self.report_404()
return
response = self.server.generate_html_documentation()
self.send_response(200)
self.send_header("Content-type", "text/html")
self.send_header("Content-length", str(len(response)))
self.end_headers()
self.wfile.write(response)
class DocXMLRPCServer( SimpleXMLRPCServer,
XMLRPCDocGenerator):
"""XML-RPC and HTML documentation server.
Adds the ability to serve server documentation to the capabilities
of SimpleXMLRPCServer.
"""
def __init__(self, addr, requestHandler=DocXMLRPCRequestHandler,
logRequests=1, allow_none=False, encoding=None,
bind_and_activate=True):
SimpleXMLRPCServer.__init__(self, addr, requestHandler, logRequests,
allow_none, encoding, bind_and_activate)
XMLRPCDocGenerator.__init__(self)
class DocCGIXMLRPCRequestHandler( CGIXMLRPCRequestHandler,
XMLRPCDocGenerator):
"""Handler for XML-RPC data and documentation requests passed through
CGI"""
def handle_get(self):
"""Handles the HTTP GET request.
Interpret all HTTP GET requests as requests for server
documentation.
"""
response = self.generate_html_documentation()
print 'Content-Type: text/html'
print 'Content-Length: %d' % len(response)
print
sys.stdout.write(response)
def __init__(self):
CGIXMLRPCRequestHandler.__init__(self)
XMLRPCDocGenerator.__init__(self)

View File

@@ -0,0 +1,390 @@
"""A parser for HTML and XHTML."""
# This file is based on sgmllib.py, but the API is slightly different.
# XXX There should be a way to distinguish between PCDATA (parsed
# character data -- the normal case), RCDATA (replaceable character
# data -- only char and entity references and end tags are special)
# and CDATA (character data -- only end tags are special).
import markupbase
import re
# Regular expressions used for parsing
interesting_normal = re.compile('[&<]')
interesting_cdata = re.compile(r'<(/|\Z)')
incomplete = re.compile('&[a-zA-Z#]')
entityref = re.compile('&([a-zA-Z][-.a-zA-Z0-9]*)[^a-zA-Z0-9]')
charref = re.compile('&#(?:[0-9]+|[xX][0-9a-fA-F]+)[^0-9a-fA-F]')
starttagopen = re.compile('<[a-zA-Z]')
piclose = re.compile('>')
commentclose = re.compile(r'--\s*>')
tagfind = re.compile('[a-zA-Z][-.a-zA-Z0-9:_]*')
attrfind = re.compile(
r'\s*([a-zA-Z_][-.:a-zA-Z_0-9]*)(\s*=\s*'
r'(\'[^\']*\'|"[^"]*"|[-a-zA-Z0-9./,:;+*%?!&$\(\)_#=~@]*))?')
locatestarttagend = re.compile(r"""
<[a-zA-Z][-.a-zA-Z0-9:_]* # tag name
(?:\s+ # whitespace before attribute name
(?:[a-zA-Z_][-.:a-zA-Z0-9_]* # attribute name
(?:\s*=\s* # value indicator
(?:'[^']*' # LITA-enclosed value
|\"[^\"]*\" # LIT-enclosed value
|[^'\">\s]+ # bare value
)
)?
)
)*
\s* # trailing whitespace
""", re.VERBOSE)
endendtag = re.compile('>')
endtagfind = re.compile('</\s*([a-zA-Z][-.a-zA-Z0-9:_]*)\s*>')
class HTMLParseError(Exception):
"""Exception raised for all parse errors."""
def __init__(self, msg, position=(None, None)):
assert msg
self.msg = msg
self.lineno = position[0]
self.offset = position[1]
def __str__(self):
result = self.msg
if self.lineno is not None:
result = result + ", at line %d" % self.lineno
if self.offset is not None:
result = result + ", column %d" % (self.offset + 1)
return result
class HTMLParser(markupbase.ParserBase):
"""Find tags and other markup and call handler functions.
Usage:
p = HTMLParser()
p.feed(data)
...
p.close()
Start tags are handled by calling self.handle_starttag() or
self.handle_startendtag(); end tags by self.handle_endtag(). The
data between tags is passed from the parser to the derived class
by calling self.handle_data() with the data as argument (the data
may be split up in arbitrary chunks). Entity references are
passed by calling self.handle_entityref() with the entity
reference as the argument. Numeric character references are
passed to self.handle_charref() with the string containing the
reference as the argument.
"""
CDATA_CONTENT_ELEMENTS = ("script", "style")
def __init__(self):
"""Initialize and reset this instance."""
self.reset()
def reset(self):
"""Reset this instance. Loses all unprocessed data."""
self.rawdata = ''
self.lasttag = '???'
self.interesting = interesting_normal
markupbase.ParserBase.reset(self)
def feed(self, data):
"""Feed data to the parser.
Call this as often as you want, with as little or as much text
as you want (may include '\n').
"""
self.rawdata = self.rawdata + data
self.goahead(0)
def close(self):
"""Handle any buffered data."""
self.goahead(1)
def error(self, message):
raise HTMLParseError(message, self.getpos())
__starttag_text = None
def get_starttag_text(self):
"""Return full source of start tag: '<...>'."""
return self.__starttag_text
def set_cdata_mode(self):
self.interesting = interesting_cdata
def clear_cdata_mode(self):
self.interesting = interesting_normal
# Internal -- handle data as far as reasonable. May leave state
# and data to be processed by a subsequent call. If 'end' is
# true, force handling all data as if followed by EOF marker.
def goahead(self, end):
rawdata = self.rawdata
i = 0
n = len(rawdata)
while i < n:
match = self.interesting.search(rawdata, i) # < or &
if match:
j = match.start()
else:
j = n
if i < j: self.handle_data(rawdata[i:j])
i = self.updatepos(i, j)
if i == n: break
startswith = rawdata.startswith
if startswith('<', i):
if starttagopen.match(rawdata, i): # < + letter
k = self.parse_starttag(i)
elif startswith("</", i):
k = self.parse_endtag(i)
elif startswith("<!--", i):
k = self.parse_comment(i)
elif startswith("<?", i):
k = self.parse_pi(i)
elif startswith("<!", i):
k = self.parse_declaration(i)
elif (i + 1) < n:
self.handle_data("<")
k = i + 1
else:
break
if k < 0:
if end:
self.error("EOF in middle of construct")
break
i = self.updatepos(i, k)
elif startswith("&#", i):
match = charref.match(rawdata, i)
if match:
name = match.group()[2:-1]
self.handle_charref(name)
k = match.end()
if not startswith(';', k-1):
k = k - 1
i = self.updatepos(i, k)
continue
else:
if ";" in rawdata[i:]: #bail by consuming &#
self.handle_data(rawdata[0:2])
i = self.updatepos(i, 2)
break
elif startswith('&', i):
match = entityref.match(rawdata, i)
if match:
name = match.group(1)
self.handle_entityref(name)
k = match.end()
if not startswith(';', k-1):
k = k - 1
i = self.updatepos(i, k)
continue
match = incomplete.match(rawdata, i)
if match:
# match.group() will contain at least 2 chars
if end and match.group() == rawdata[i:]:
self.error("EOF in middle of entity or char ref")
# incomplete
break
elif (i + 1) < n:
# not the end of the buffer, and can't be confused
# with some other construct
self.handle_data("&")
i = self.updatepos(i, i + 1)
else:
break
else:
assert 0, "interesting.search() lied"
# end while
if end and i < n:
self.handle_data(rawdata[i:n])
i = self.updatepos(i, n)
self.rawdata = rawdata[i:]
# Internal -- parse processing instr, return end or -1 if not terminated
def parse_pi(self, i):
rawdata = self.rawdata
assert rawdata[i:i+2] == '<?', 'unexpected call to parse_pi()'
match = piclose.search(rawdata, i+2) # >
if not match:
return -1
j = match.start()
self.handle_pi(rawdata[i+2: j])
j = match.end()
return j
# Internal -- handle starttag, return end or -1 if not terminated
def parse_starttag(self, i):
self.__starttag_text = None
endpos = self.check_for_whole_start_tag(i)
if endpos < 0:
return endpos
rawdata = self.rawdata
self.__starttag_text = rawdata[i:endpos]
# Now parse the data between i+1 and j into a tag and attrs
attrs = []
match = tagfind.match(rawdata, i+1)
assert match, 'unexpected call to parse_starttag()'
k = match.end()
self.lasttag = tag = rawdata[i+1:k].lower()
while k < endpos:
m = attrfind.match(rawdata, k)
if not m:
break
attrname, rest, attrvalue = m.group(1, 2, 3)
if not rest:
attrvalue = None
elif attrvalue[:1] == '\'' == attrvalue[-1:] or \
attrvalue[:1] == '"' == attrvalue[-1:]:
attrvalue = attrvalue[1:-1]
attrvalue = self.unescape(attrvalue)
attrs.append((attrname.lower(), attrvalue))
k = m.end()
end = rawdata[k:endpos].strip()
if end not in (">", "/>"):
lineno, offset = self.getpos()
if "\n" in self.__starttag_text:
lineno = lineno + self.__starttag_text.count("\n")
offset = len(self.__starttag_text) \
- self.__starttag_text.rfind("\n")
else:
offset = offset + len(self.__starttag_text)
self.error("junk characters in start tag: %r"
% (rawdata[k:endpos][:20],))
if end.endswith('/>'):
# XHTML-style empty tag: <span attr="value" />
self.handle_startendtag(tag, attrs)
else:
self.handle_starttag(tag, attrs)
if tag in self.CDATA_CONTENT_ELEMENTS:
self.set_cdata_mode()
return endpos
# Internal -- check to see if we have a complete starttag; return end
# or -1 if incomplete.
def check_for_whole_start_tag(self, i):
rawdata = self.rawdata
m = locatestarttagend.match(rawdata, i)
if m:
j = m.end()
next = rawdata[j:j+1]
if next == ">":
return j + 1
if next == "/":
if rawdata.startswith("/>", j):
return j + 2
if rawdata.startswith("/", j):
# buffer boundary
return -1
# else bogus input
self.updatepos(i, j + 1)
self.error("malformed empty start tag")
if next == "":
# end of input
return -1
if next in ("abcdefghijklmnopqrstuvwxyz=/"
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"):
# end of input in or before attribute value, or we have the
# '/' from a '/>' ending
return -1
self.updatepos(i, j)
self.error("malformed start tag")
raise AssertionError("we should not get here!")
# Internal -- parse endtag, return end or -1 if incomplete
def parse_endtag(self, i):
rawdata = self.rawdata
assert rawdata[i:i+2] == "</", "unexpected call to parse_endtag"
match = endendtag.search(rawdata, i+1) # >
if not match:
return -1
j = match.end()
match = endtagfind.match(rawdata, i) # </ + tag + >
if not match:
self.error("bad end tag: %r" % (rawdata[i:j],))
tag = match.group(1)
self.handle_endtag(tag.lower())
self.clear_cdata_mode()
return j
# Overridable -- finish processing of start+end tag: <tag.../>
def handle_startendtag(self, tag, attrs):
self.handle_starttag(tag, attrs)
self.handle_endtag(tag)
# Overridable -- handle start tag
def handle_starttag(self, tag, attrs):
pass
# Overridable -- handle end tag
def handle_endtag(self, tag):
pass
# Overridable -- handle character reference
def handle_charref(self, name):
pass
# Overridable -- handle entity reference
def handle_entityref(self, name):
pass
# Overridable -- handle data
def handle_data(self, data):
pass
# Overridable -- handle comment
def handle_comment(self, data):
pass
# Overridable -- handle declaration
def handle_decl(self, decl):
pass
# Overridable -- handle processing instruction
def handle_pi(self, data):
pass
def unknown_decl(self, data):
self.error("unknown declaration: %r" % (data,))
# Internal -- helper to remove special character quoting
entitydefs = None
def unescape(self, s):
if '&' not in s:
return s
def replaceEntities(s):
s = s.groups()[0]
if s[0] == "#":
s = s[1:]
if s[0] in ['x','X']:
c = int(s[1:], 16)
else:
c = int(s)
return unichr(c)
else:
# Cannot use name2codepoint directly, because HTMLParser supports apos,
# which is not part of HTML 4
import htmlentitydefs
if HTMLParser.entitydefs is None:
entitydefs = HTMLParser.entitydefs = {'apos':u"'"}
for k, v in htmlentitydefs.name2codepoint.iteritems():
entitydefs[k] = unichr(v)
try:
return self.entitydefs[s]
except KeyError:
return '&'+s+';'
return re.sub(r"&(#?[xX]?(?:[0-9a-fA-F]+|\w{1,8}));", replaceEntities, s)

View File

@@ -0,0 +1,186 @@
"""Generic MIME writer.
This module defines the class MimeWriter. The MimeWriter class implements
a basic formatter for creating MIME multi-part files. It doesn't seek around
the output file nor does it use large amounts of buffer space. You must write
the parts out in the order that they should occur in the final file.
MimeWriter does buffer the headers you add, allowing you to rearrange their
order.
"""
import mimetools
__all__ = ["MimeWriter"]
import warnings
warnings.warn("the MimeWriter module is deprecated; use the email package instead",
DeprecationWarning, 2)
class MimeWriter:
"""Generic MIME writer.
Methods:
__init__()
addheader()
flushheaders()
startbody()
startmultipartbody()
nextpart()
lastpart()
A MIME writer is much more primitive than a MIME parser. It
doesn't seek around on the output file, and it doesn't use large
amounts of buffer space, so you have to write the parts in the
order they should occur on the output file. It does buffer the
headers you add, allowing you to rearrange their order.
General usage is:
f = <open the output file>
w = MimeWriter(f)
...call w.addheader(key, value) 0 or more times...
followed by either:
f = w.startbody(content_type)
...call f.write(data) for body data...
or:
w.startmultipartbody(subtype)
for each part:
subwriter = w.nextpart()
...use the subwriter's methods to create the subpart...
w.lastpart()
The subwriter is another MimeWriter instance, and should be
treated in the same way as the toplevel MimeWriter. This way,
writing recursive body parts is easy.
Warning: don't forget to call lastpart()!
XXX There should be more state so calls made in the wrong order
are detected.
Some special cases:
- startbody() just returns the file passed to the constructor;
but don't use this knowledge, as it may be changed.
- startmultipartbody() actually returns a file as well;
this can be used to write the initial 'if you can read this your
mailer is not MIME-aware' message.
- If you call flushheaders(), the headers accumulated so far are
written out (and forgotten); this is useful if you don't need a
body part at all, e.g. for a subpart of type message/rfc822
that's (mis)used to store some header-like information.
- Passing a keyword argument 'prefix=<flag>' to addheader(),
start*body() affects where the header is inserted; 0 means
append at the end, 1 means insert at the start; default is
append for addheader(), but insert for start*body(), which use
it to determine where the Content-Type header goes.
"""
def __init__(self, fp):
self._fp = fp
self._headers = []
def addheader(self, key, value, prefix=0):
"""Add a header line to the MIME message.
The key is the name of the header, where the value obviously provides
the value of the header. The optional argument prefix determines
where the header is inserted; 0 means append at the end, 1 means
insert at the start. The default is to append.
"""
lines = value.split("\n")
while lines and not lines[-1]: del lines[-1]
while lines and not lines[0]: del lines[0]
for i in range(1, len(lines)):
lines[i] = " " + lines[i].strip()
value = "\n".join(lines) + "\n"
line = key + ": " + value
if prefix:
self._headers.insert(0, line)
else:
self._headers.append(line)
def flushheaders(self):
"""Writes out and forgets all headers accumulated so far.
This is useful if you don't need a body part at all; for example,
for a subpart of type message/rfc822 that's (mis)used to store some
header-like information.
"""
self._fp.writelines(self._headers)
self._headers = []
def startbody(self, ctype, plist=[], prefix=1):
"""Returns a file-like object for writing the body of the message.
The content-type is set to the provided ctype, and the optional
parameter, plist, provides additional parameters for the
content-type declaration. The optional argument prefix determines
where the header is inserted; 0 means append at the end, 1 means
insert at the start. The default is to insert at the start.
"""
for name, value in plist:
ctype = ctype + ';\n %s=\"%s\"' % (name, value)
self.addheader("Content-Type", ctype, prefix=prefix)
self.flushheaders()
self._fp.write("\n")
return self._fp
def startmultipartbody(self, subtype, boundary=None, plist=[], prefix=1):
"""Returns a file-like object for writing the body of the message.
Additionally, this method initializes the multi-part code, where the
subtype parameter provides the multipart subtype, the boundary
parameter may provide a user-defined boundary specification, and the
plist parameter provides optional parameters for the subtype. The
optional argument, prefix, determines where the header is inserted;
0 means append at the end, 1 means insert at the start. The default
is to insert at the start. Subparts should be created using the
nextpart() method.
"""
self._boundary = boundary or mimetools.choose_boundary()
return self.startbody("multipart/" + subtype,
[("boundary", self._boundary)] + plist,
prefix=prefix)
def nextpart(self):
"""Returns a new instance of MimeWriter which represents an
individual part in a multipart message.
This may be used to write the part as well as used for creating
recursively complex multipart messages. The message must first be
initialized with the startmultipartbody() method before using the
nextpart() method.
"""
self._fp.write("\n--" + self._boundary + "\n")
return self.__class__(self._fp)
def lastpart(self):
"""This is used to designate the last part of a multipart message.
It should always be used when writing multipart messages.
"""
self._fp.write("\n--" + self._boundary + "--\n")
if __name__ == '__main__':
import test.test_MimeWriter

244
src/python/stdlib/Queue.py Normal file
View File

@@ -0,0 +1,244 @@
"""A multi-producer, multi-consumer queue."""
from time import time as _time
try:
import threading as _threading
except ImportError:
import dummy_threading as _threading
from collections import deque
import heapq
__all__ = ['Empty', 'Full', 'Queue', 'PriorityQueue', 'LifoQueue']
class Empty(Exception):
"Exception raised by Queue.get(block=0)/get_nowait()."
pass
class Full(Exception):
"Exception raised by Queue.put(block=0)/put_nowait()."
pass
class Queue:
"""Create a queue object with a given maximum size.
If maxsize is <= 0, the queue size is infinite.
"""
def __init__(self, maxsize=0):
self.maxsize = maxsize
self._init(maxsize)
# mutex must be held whenever the queue is mutating. All methods
# that acquire mutex must release it before returning. mutex
# is shared between the three conditions, so acquiring and
# releasing the conditions also acquires and releases mutex.
self.mutex = _threading.Lock()
# Notify not_empty whenever an item is added to the queue; a
# thread waiting to get is notified then.
self.not_empty = _threading.Condition(self.mutex)
# Notify not_full whenever an item is removed from the queue;
# a thread waiting to put is notified then.
self.not_full = _threading.Condition(self.mutex)
# Notify all_tasks_done whenever the number of unfinished tasks
# drops to zero; thread waiting to join() is notified to resume
self.all_tasks_done = _threading.Condition(self.mutex)
self.unfinished_tasks = 0
def task_done(self):
"""Indicate that a formerly enqueued task is complete.
Used by Queue consumer threads. For each get() used to fetch a task,
a subsequent call to task_done() tells the queue that the processing
on the task is complete.
If a join() is currently blocking, it will resume when all items
have been processed (meaning that a task_done() call was received
for every item that had been put() into the queue).
Raises a ValueError if called more times than there were items
placed in the queue.
"""
self.all_tasks_done.acquire()
try:
unfinished = self.unfinished_tasks - 1
if unfinished <= 0:
if unfinished < 0:
raise ValueError('task_done() called too many times')
self.all_tasks_done.notify_all()
self.unfinished_tasks = unfinished
finally:
self.all_tasks_done.release()
def join(self):
"""Blocks until all items in the Queue have been gotten and processed.
The count of unfinished tasks goes up whenever an item is added to the
queue. The count goes down whenever a consumer thread calls task_done()
to indicate the item was retrieved and all work on it is complete.
When the count of unfinished tasks drops to zero, join() unblocks.
"""
self.all_tasks_done.acquire()
try:
while self.unfinished_tasks:
self.all_tasks_done.wait()
finally:
self.all_tasks_done.release()
def qsize(self):
"""Return the approximate size of the queue (not reliable!)."""
self.mutex.acquire()
n = self._qsize()
self.mutex.release()
return n
def empty(self):
"""Return True if the queue is empty, False otherwise (not reliable!)."""
self.mutex.acquire()
n = not self._qsize()
self.mutex.release()
return n
def full(self):
"""Return True if the queue is full, False otherwise (not reliable!)."""
self.mutex.acquire()
n = 0 < self.maxsize == self._qsize()
self.mutex.release()
return n
def put(self, item, block=True, timeout=None):
"""Put an item into the queue.
If optional args 'block' is true and 'timeout' is None (the default),
block if necessary until a free slot is available. If 'timeout' is
a positive number, it blocks at most 'timeout' seconds and raises
the Full exception if no free slot was available within that time.
Otherwise ('block' is false), put an item on the queue if a free slot
is immediately available, else raise the Full exception ('timeout'
is ignored in that case).
"""
self.not_full.acquire()
try:
if self.maxsize > 0:
if not block:
if self._qsize() == self.maxsize:
raise Full
elif timeout is None:
while self._qsize() == self.maxsize:
self.not_full.wait()
elif timeout < 0:
raise ValueError("'timeout' must be a positive number")
else:
endtime = _time() + timeout
while self._qsize() == self.maxsize:
remaining = endtime - _time()
if remaining <= 0.0:
raise Full
self.not_full.wait(remaining)
self._put(item)
self.unfinished_tasks += 1
self.not_empty.notify()
finally:
self.not_full.release()
def put_nowait(self, item):
"""Put an item into the queue without blocking.
Only enqueue the item if a free slot is immediately available.
Otherwise raise the Full exception.
"""
return self.put(item, False)
def get(self, block=True, timeout=None):
"""Remove and return an item from the queue.
If optional args 'block' is true and 'timeout' is None (the default),
block if necessary until an item is available. If 'timeout' is
a positive number, it blocks at most 'timeout' seconds and raises
the Empty exception if no item was available within that time.
Otherwise ('block' is false), return an item if one is immediately
available, else raise the Empty exception ('timeout' is ignored
in that case).
"""
self.not_empty.acquire()
try:
if not block:
if not self._qsize():
raise Empty
elif timeout is None:
while not self._qsize():
self.not_empty.wait()
elif timeout < 0:
raise ValueError("'timeout' must be a positive number")
else:
endtime = _time() + timeout
while not self._qsize():
remaining = endtime - _time()
if remaining <= 0.0:
raise Empty
self.not_empty.wait(remaining)
item = self._get()
self.not_full.notify()
return item
finally:
self.not_empty.release()
def get_nowait(self):
"""Remove and return an item from the queue without blocking.
Only get an item if one is immediately available. Otherwise
raise the Empty exception.
"""
return self.get(False)
# Override these methods to implement other queue organizations
# (e.g. stack or priority queue).
# These will only be called with appropriate locks held
# Initialize the queue representation
def _init(self, maxsize):
self.queue = deque()
def _qsize(self, len=len):
return len(self.queue)
# Put a new item in the queue
def _put(self, item):
self.queue.append(item)
# Get an item from the queue
def _get(self):
return self.queue.popleft()
class PriorityQueue(Queue):
'''Variant of Queue that retrieves open entries in priority order (lowest first).
Entries are typically tuples of the form: (priority number, data).
'''
def _init(self, maxsize):
self.queue = []
def _qsize(self, len=len):
return len(self.queue)
def _put(self, item, heappush=heapq.heappush):
heappush(self.queue, item)
def _get(self, heappop=heapq.heappop):
return heappop(self.queue)
class LifoQueue(Queue):
'''Variant of Queue that retrieves most recently added entries first.'''
def _init(self, maxsize):
self.queue = []
def _qsize(self, len=len):
return len(self.queue)
def _put(self, item):
self.queue.append(item)
def _get(self):
return self.queue.pop()

View File

@@ -0,0 +1,218 @@
"""Simple HTTP Server.
This module builds on BaseHTTPServer by implementing the standard GET
and HEAD requests in a fairly straightforward manner.
"""
__version__ = "0.6"
__all__ = ["SimpleHTTPRequestHandler"]
import os
import posixpath
import BaseHTTPServer
import urllib
import cgi
import shutil
import mimetypes
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO
class SimpleHTTPRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
"""Simple HTTP request handler with GET and HEAD commands.
This serves files from the current directory and any of its
subdirectories. The MIME type for files is determined by
calling the .guess_type() method.
The GET and HEAD requests are identical except that the HEAD
request omits the actual contents of the file.
"""
server_version = "SimpleHTTP/" + __version__
def do_GET(self):
"""Serve a GET request."""
f = self.send_head()
if f:
self.copyfile(f, self.wfile)
f.close()
def do_HEAD(self):
"""Serve a HEAD request."""
f = self.send_head()
if f:
f.close()
def send_head(self):
"""Common code for GET and HEAD commands.
This sends the response code and MIME headers.
Return value is either a file object (which has to be copied
to the outputfile by the caller unless the command was HEAD,
and must be closed by the caller under all circumstances), or
None, in which case the caller has nothing further to do.
"""
path = self.translate_path(self.path)
f = None
if os.path.isdir(path):
if not self.path.endswith('/'):
# redirect browser - doing basically what apache does
self.send_response(301)
self.send_header("Location", self.path + "/")
self.end_headers()
return None
for index in "index.html", "index.htm":
index = os.path.join(path, index)
if os.path.exists(index):
path = index
break
else:
return self.list_directory(path)
ctype = self.guess_type(path)
try:
# Always read in binary mode. Opening files in text mode may cause
# newline translations, making the actual size of the content
# transmitted *less* than the content-length!
f = open(path, 'rb')
except IOError:
self.send_error(404, "File not found")
return None
self.send_response(200)
self.send_header("Content-type", ctype)
fs = os.fstat(f.fileno())
self.send_header("Content-Length", str(fs[6]))
self.send_header("Last-Modified", self.date_time_string(fs.st_mtime))
self.end_headers()
return f
def list_directory(self, path):
"""Helper to produce a directory listing (absent index.html).
Return value is either a file object, or None (indicating an
error). In either case, the headers are sent, making the
interface the same as for send_head().
"""
try:
list = os.listdir(path)
except os.error:
self.send_error(404, "No permission to list directory")
return None
list.sort(key=lambda a: a.lower())
f = StringIO()
displaypath = cgi.escape(urllib.unquote(self.path))
f.write('<!DOCTYPE html PUBLIC "-//W3C//DTD HTML 3.2 Final//EN">')
f.write("<html>\n<title>Directory listing for %s</title>\n" % displaypath)
f.write("<body>\n<h2>Directory listing for %s</h2>\n" % displaypath)
f.write("<hr>\n<ul>\n")
for name in list:
fullname = os.path.join(path, name)
displayname = linkname = name
# Append / for directories or @ for symbolic links
if os.path.isdir(fullname):
displayname = name + "/"
linkname = name + "/"
if os.path.islink(fullname):
displayname = name + "@"
# Note: a link to a directory displays with @ and links with /
f.write('<li><a href="%s">%s</a>\n'
% (urllib.quote(linkname), cgi.escape(displayname)))
f.write("</ul>\n<hr>\n</body>\n</html>\n")
length = f.tell()
f.seek(0)
self.send_response(200)
self.send_header("Content-type", "text/html")
self.send_header("Content-Length", str(length))
self.end_headers()
return f
def translate_path(self, path):
"""Translate a /-separated PATH to the local filename syntax.
Components that mean special things to the local file system
(e.g. drive or directory names) are ignored. (XXX They should
probably be diagnosed.)
"""
# abandon query parameters
path = path.split('?',1)[0]
path = path.split('#',1)[0]
path = posixpath.normpath(urllib.unquote(path))
words = path.split('/')
words = filter(None, words)
path = os.getcwd()
for word in words:
drive, word = os.path.splitdrive(word)
head, word = os.path.split(word)
if word in (os.curdir, os.pardir): continue
path = os.path.join(path, word)
return path
def copyfile(self, source, outputfile):
"""Copy all data between two file objects.
The SOURCE argument is a file object open for reading
(or anything with a read() method) and the DESTINATION
argument is a file object open for writing (or
anything with a write() method).
The only reason for overriding this would be to change
the block size or perhaps to replace newlines by CRLF
-- note however that this the default server uses this
to copy binary data as well.
"""
shutil.copyfileobj(source, outputfile)
def guess_type(self, path):
"""Guess the type of a file.
Argument is a PATH (a filename).
Return value is a string of the form type/subtype,
usable for a MIME Content-type header.
The default implementation looks the file's extension
up in the table self.extensions_map, using application/octet-stream
as a default; however it would be permissible (if
slow) to look inside the data to make a better guess.
"""
base, ext = posixpath.splitext(path)
if ext in self.extensions_map:
return self.extensions_map[ext]
ext = ext.lower()
if ext in self.extensions_map:
return self.extensions_map[ext]
else:
return self.extensions_map['']
if not mimetypes.inited:
mimetypes.init() # try to read system mime.types
extensions_map = mimetypes.types_map.copy()
extensions_map.update({
'': 'application/octet-stream', # Default
'.py': 'text/plain',
'.c': 'text/plain',
'.h': 'text/plain',
})
def test(HandlerClass = SimpleHTTPRequestHandler,
ServerClass = BaseHTTPServer.HTTPServer):
BaseHTTPServer.test(HandlerClass, ServerClass)
if __name__ == '__main__':
test()

View File

@@ -0,0 +1,704 @@
"""Simple XML-RPC Server.
This module can be used to create simple XML-RPC servers
by creating a server and either installing functions, a
class instance, or by extending the SimpleXMLRPCServer
class.
It can also be used to handle XML-RPC requests in a CGI
environment using CGIXMLRPCRequestHandler.
A list of possible usage patterns follows:
1. Install functions:
server = SimpleXMLRPCServer(("localhost", 8000))
server.register_function(pow)
server.register_function(lambda x,y: x+y, 'add')
server.serve_forever()
2. Install an instance:
class MyFuncs:
def __init__(self):
# make all of the string functions available through
# string.func_name
import string
self.string = string
def _listMethods(self):
# implement this method so that system.listMethods
# knows to advertise the strings methods
return list_public_methods(self) + \
['string.' + method for method in list_public_methods(self.string)]
def pow(self, x, y): return pow(x, y)
def add(self, x, y) : return x + y
server = SimpleXMLRPCServer(("localhost", 8000))
server.register_introspection_functions()
server.register_instance(MyFuncs())
server.serve_forever()
3. Install an instance with custom dispatch method:
class Math:
def _listMethods(self):
# this method must be present for system.listMethods
# to work
return ['add', 'pow']
def _methodHelp(self, method):
# this method must be present for system.methodHelp
# to work
if method == 'add':
return "add(2,3) => 5"
elif method == 'pow':
return "pow(x, y[, z]) => number"
else:
# By convention, return empty
# string if no help is available
return ""
def _dispatch(self, method, params):
if method == 'pow':
return pow(*params)
elif method == 'add':
return params[0] + params[1]
else:
raise 'bad method'
server = SimpleXMLRPCServer(("localhost", 8000))
server.register_introspection_functions()
server.register_instance(Math())
server.serve_forever()
4. Subclass SimpleXMLRPCServer:
class MathServer(SimpleXMLRPCServer):
def _dispatch(self, method, params):
try:
# We are forcing the 'export_' prefix on methods that are
# callable through XML-RPC to prevent potential security
# problems
func = getattr(self, 'export_' + method)
except AttributeError:
raise Exception('method "%s" is not supported' % method)
else:
return func(*params)
def export_add(self, x, y):
return x + y
server = MathServer(("localhost", 8000))
server.serve_forever()
5. CGI script:
server = CGIXMLRPCRequestHandler()
server.register_function(pow)
server.handle_request()
"""
# Written by Brian Quinlan (brian@sweetapp.com).
# Based on code written by Fredrik Lundh.
import xmlrpclib
from xmlrpclib import Fault
import SocketServer
import BaseHTTPServer
import sys
import os
import traceback
import re
try:
import fcntl
except ImportError:
fcntl = None
def resolve_dotted_attribute(obj, attr, allow_dotted_names=True):
"""resolve_dotted_attribute(a, 'b.c.d') => a.b.c.d
Resolves a dotted attribute name to an object. Raises
an AttributeError if any attribute in the chain starts with a '_'.
If the optional allow_dotted_names argument is false, dots are not
supported and this function operates similar to getattr(obj, attr).
"""
if allow_dotted_names:
attrs = attr.split('.')
else:
attrs = [attr]
for i in attrs:
if i.startswith('_'):
raise AttributeError(
'attempt to access private attribute "%s"' % i
)
else:
obj = getattr(obj,i)
return obj
def list_public_methods(obj):
"""Returns a list of attribute strings, found in the specified
object, which represent callable attributes"""
return [member for member in dir(obj)
if not member.startswith('_') and
hasattr(getattr(obj, member), '__call__')]
def remove_duplicates(lst):
"""remove_duplicates([2,2,2,1,3,3]) => [3,1,2]
Returns a copy of a list without duplicates. Every list
item must be hashable and the order of the items in the
resulting list is not defined.
"""
u = {}
for x in lst:
u[x] = 1
return u.keys()
class SimpleXMLRPCDispatcher:
"""Mix-in class that dispatches XML-RPC requests.
This class is used to register XML-RPC method handlers
and then to dispatch them. This class doesn't need to be
instanced directly when used by SimpleXMLRPCServer but it
can be instanced when used by the MultiPathXMLRPCServer.
"""
def __init__(self, allow_none=False, encoding=None):
self.funcs = {}
self.instance = None
self.allow_none = allow_none
self.encoding = encoding
def register_instance(self, instance, allow_dotted_names=False):
"""Registers an instance to respond to XML-RPC requests.
Only one instance can be installed at a time.
If the registered instance has a _dispatch method then that
method will be called with the name of the XML-RPC method and
its parameters as a tuple
e.g. instance._dispatch('add',(2,3))
If the registered instance does not have a _dispatch method
then the instance will be searched to find a matching method
and, if found, will be called. Methods beginning with an '_'
are considered private and will not be called by
SimpleXMLRPCServer.
If a registered function matches a XML-RPC request, then it
will be called instead of the registered instance.
If the optional allow_dotted_names argument is true and the
instance does not have a _dispatch method, method names
containing dots are supported and resolved, as long as none of
the name segments start with an '_'.
*** SECURITY WARNING: ***
Enabling the allow_dotted_names options allows intruders
to access your module's global variables and may allow
intruders to execute arbitrary code on your machine. Only
use this option on a secure, closed network.
"""
self.instance = instance
self.allow_dotted_names = allow_dotted_names
def register_function(self, function, name = None):
"""Registers a function to respond to XML-RPC requests.
The optional name argument can be used to set a Unicode name
for the function.
"""
if name is None:
name = function.__name__
self.funcs[name] = function
def register_introspection_functions(self):
"""Registers the XML-RPC introspection methods in the system
namespace.
see http://xmlrpc.usefulinc.com/doc/reserved.html
"""
self.funcs.update({'system.listMethods' : self.system_listMethods,
'system.methodSignature' : self.system_methodSignature,
'system.methodHelp' : self.system_methodHelp})
def register_multicall_functions(self):
"""Registers the XML-RPC multicall method in the system
namespace.
see http://www.xmlrpc.com/discuss/msgReader$1208"""
self.funcs.update({'system.multicall' : self.system_multicall})
def _marshaled_dispatch(self, data, dispatch_method = None, path = None):
"""Dispatches an XML-RPC method from marshalled (XML) data.
XML-RPC methods are dispatched from the marshalled (XML) data
using the _dispatch method and the result is returned as
marshalled data. For backwards compatibility, a dispatch
function can be provided as an argument (see comment in
SimpleXMLRPCRequestHandler.do_POST) but overriding the
existing method through subclassing is the prefered means
of changing method dispatch behavior.
"""
try:
params, method = xmlrpclib.loads(data)
# generate response
if dispatch_method is not None:
response = dispatch_method(method, params)
else:
response = self._dispatch(method, params)
# wrap response in a singleton tuple
response = (response,)
response = xmlrpclib.dumps(response, methodresponse=1,
allow_none=self.allow_none, encoding=self.encoding)
except Fault, fault:
response = xmlrpclib.dumps(fault, allow_none=self.allow_none,
encoding=self.encoding)
except:
# report exception back to server
exc_type, exc_value, exc_tb = sys.exc_info()
response = xmlrpclib.dumps(
xmlrpclib.Fault(1, "%s:%s" % (exc_type, exc_value)),
encoding=self.encoding, allow_none=self.allow_none,
)
return response
def system_listMethods(self):
"""system.listMethods() => ['add', 'subtract', 'multiple']
Returns a list of the methods supported by the server."""
methods = self.funcs.keys()
if self.instance is not None:
# Instance can implement _listMethod to return a list of
# methods
if hasattr(self.instance, '_listMethods'):
methods = remove_duplicates(
methods + self.instance._listMethods()
)
# if the instance has a _dispatch method then we
# don't have enough information to provide a list
# of methods
elif not hasattr(self.instance, '_dispatch'):
methods = remove_duplicates(
methods + list_public_methods(self.instance)
)
methods.sort()
return methods
def system_methodSignature(self, method_name):
"""system.methodSignature('add') => [double, int, int]
Returns a list describing the signature of the method. In the
above example, the add method takes two integers as arguments
and returns a double result.
This server does NOT support system.methodSignature."""
# See http://xmlrpc.usefulinc.com/doc/sysmethodsig.html
return 'signatures not supported'
def system_methodHelp(self, method_name):
"""system.methodHelp('add') => "Adds two integers together"
Returns a string containing documentation for the specified method."""
method = None
if method_name in self.funcs:
method = self.funcs[method_name]
elif self.instance is not None:
# Instance can implement _methodHelp to return help for a method
if hasattr(self.instance, '_methodHelp'):
return self.instance._methodHelp(method_name)
# if the instance has a _dispatch method then we
# don't have enough information to provide help
elif not hasattr(self.instance, '_dispatch'):
try:
method = resolve_dotted_attribute(
self.instance,
method_name,
self.allow_dotted_names
)
except AttributeError:
pass
# Note that we aren't checking that the method actually
# be a callable object of some kind
if method is None:
return ""
else:
import pydoc
return pydoc.getdoc(method)
def system_multicall(self, call_list):
"""system.multicall([{'methodName': 'add', 'params': [2, 2]}, ...]) => \
[[4], ...]
Allows the caller to package multiple XML-RPC calls into a single
request.
See http://www.xmlrpc.com/discuss/msgReader$1208
"""
results = []
for call in call_list:
method_name = call['methodName']
params = call['params']
try:
# XXX A marshalling error in any response will fail the entire
# multicall. If someone cares they should fix this.
results.append([self._dispatch(method_name, params)])
except Fault, fault:
results.append(
{'faultCode' : fault.faultCode,
'faultString' : fault.faultString}
)
except:
exc_type, exc_value, exc_tb = sys.exc_info()
results.append(
{'faultCode' : 1,
'faultString' : "%s:%s" % (exc_type, exc_value)}
)
return results
def _dispatch(self, method, params):
"""Dispatches the XML-RPC method.
XML-RPC calls are forwarded to a registered function that
matches the called XML-RPC method name. If no such function
exists then the call is forwarded to the registered instance,
if available.
If the registered instance has a _dispatch method then that
method will be called with the name of the XML-RPC method and
its parameters as a tuple
e.g. instance._dispatch('add',(2,3))
If the registered instance does not have a _dispatch method
then the instance will be searched to find a matching method
and, if found, will be called.
Methods beginning with an '_' are considered private and will
not be called.
"""
func = None
try:
# check to see if a matching function has been registered
func = self.funcs[method]
except KeyError:
if self.instance is not None:
# check for a _dispatch method
if hasattr(self.instance, '_dispatch'):
return self.instance._dispatch(method, params)
else:
# call instance method directly
try:
func = resolve_dotted_attribute(
self.instance,
method,
self.allow_dotted_names
)
except AttributeError:
pass
if func is not None:
return func(*params)
else:
raise Exception('method "%s" is not supported' % method)
class SimpleXMLRPCRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
"""Simple XML-RPC request handler class.
Handles all HTTP POST requests and attempts to decode them as
XML-RPC requests.
"""
# Class attribute listing the accessible path components;
# paths not on this list will result in a 404 error.
rpc_paths = ('/', '/RPC2')
#if not None, encode responses larger than this, if possible
encode_threshold = 1400 #a common MTU
#Override form StreamRequestHandler: full buffering of output
#and no Nagle.
wbufsize = -1
disable_nagle_algorithm = True
# a re to match a gzip Accept-Encoding
aepattern = re.compile(r"""
\s* ([^\s;]+) \s* #content-coding
(;\s* q \s*=\s* ([0-9\.]+))? #q
""", re.VERBOSE | re.IGNORECASE)
def accept_encodings(self):
r = {}
ae = self.headers.get("Accept-Encoding", "")
for e in ae.split(","):
match = self.aepattern.match(e)
if match:
v = match.group(3)
v = float(v) if v else 1.0
r[match.group(1)] = v
return r
def is_rpc_path_valid(self):
if self.rpc_paths:
return self.path in self.rpc_paths
else:
# If .rpc_paths is empty, just assume all paths are legal
return True
def do_POST(self):
"""Handles the HTTP POST request.
Attempts to interpret all HTTP POST requests as XML-RPC calls,
which are forwarded to the server's _dispatch method for handling.
"""
# Check that the path is legal
if not self.is_rpc_path_valid():
self.report_404()
return
try:
# Get arguments by reading body of request.
# We read this in chunks to avoid straining
# socket.read(); around the 10 or 15Mb mark, some platforms
# begin to have problems (bug #792570).
max_chunk_size = 10*1024*1024
size_remaining = int(self.headers["content-length"])
L = []
while size_remaining:
chunk_size = min(size_remaining, max_chunk_size)
L.append(self.rfile.read(chunk_size))
size_remaining -= len(L[-1])
data = ''.join(L)
data = self.decode_request_content(data)
if data is None:
return #response has been sent
# In previous versions of SimpleXMLRPCServer, _dispatch
# could be overridden in this class, instead of in
# SimpleXMLRPCDispatcher. To maintain backwards compatibility,
# check to see if a subclass implements _dispatch and dispatch
# using that method if present.
response = self.server._marshaled_dispatch(
data, getattr(self, '_dispatch', None), self.path
)
except Exception, e: # This should only happen if the module is buggy
# internal error, report as HTTP server error
self.send_response(500)
# Send information about the exception if requested
if hasattr(self.server, '_send_traceback_header') and \
self.server._send_traceback_header:
self.send_header("X-exception", str(e))
self.send_header("X-traceback", traceback.format_exc())
self.send_header("Content-length", "0")
self.end_headers()
else:
# got a valid XML RPC response
self.send_response(200)
self.send_header("Content-type", "text/xml")
if self.encode_threshold is not None:
if len(response) > self.encode_threshold:
q = self.accept_encodings().get("gzip", 0)
if q:
try:
response = xmlrpclib.gzip_encode(response)
self.send_header("Content-Encoding", "gzip")
except NotImplementedError:
pass
self.send_header("Content-length", str(len(response)))
self.end_headers()
self.wfile.write(response)
def decode_request_content(self, data):
#support gzip encoding of request
encoding = self.headers.get("content-encoding", "identity").lower()
if encoding == "identity":
return data
if encoding == "gzip":
try:
return xmlrpclib.gzip_decode(data)
except NotImplementedError:
self.send_response(501, "encoding %r not supported" % encoding)
except ValueError:
self.send_response(400, "error decoding gzip content")
else:
self.send_response(501, "encoding %r not supported" % encoding)
self.send_header("Content-length", "0")
self.end_headers()
def report_404 (self):
# Report a 404 error
self.send_response(404)
response = 'No such page'
self.send_header("Content-type", "text/plain")
self.send_header("Content-length", str(len(response)))
self.end_headers()
self.wfile.write(response)
def log_request(self, code='-', size='-'):
"""Selectively log an accepted request."""
if self.server.logRequests:
BaseHTTPServer.BaseHTTPRequestHandler.log_request(self, code, size)
class SimpleXMLRPCServer(SocketServer.TCPServer,
SimpleXMLRPCDispatcher):
"""Simple XML-RPC server.
Simple XML-RPC server that allows functions and a single instance
to be installed to handle requests. The default implementation
attempts to dispatch XML-RPC calls to the functions or instance
installed in the server. Override the _dispatch method inhereted
from SimpleXMLRPCDispatcher to change this behavior.
"""
allow_reuse_address = True
# Warning: this is for debugging purposes only! Never set this to True in
# production code, as will be sending out sensitive information (exception
# and stack trace details) when exceptions are raised inside
# SimpleXMLRPCRequestHandler.do_POST
_send_traceback_header = False
def __init__(self, addr, requestHandler=SimpleXMLRPCRequestHandler,
logRequests=True, allow_none=False, encoding=None, bind_and_activate=True):
self.logRequests = logRequests
SimpleXMLRPCDispatcher.__init__(self, allow_none, encoding)
SocketServer.TCPServer.__init__(self, addr, requestHandler, bind_and_activate)
# [Bug #1222790] If possible, set close-on-exec flag; if a
# method spawns a subprocess, the subprocess shouldn't have
# the listening socket open.
if fcntl is not None and hasattr(fcntl, 'FD_CLOEXEC'):
flags = fcntl.fcntl(self.fileno(), fcntl.F_GETFD)
flags |= fcntl.FD_CLOEXEC
fcntl.fcntl(self.fileno(), fcntl.F_SETFD, flags)
class MultiPathXMLRPCServer(SimpleXMLRPCServer):
"""Multipath XML-RPC Server
This specialization of SimpleXMLRPCServer allows the user to create
multiple Dispatcher instances and assign them to different
HTTP request paths. This makes it possible to run two or more
'virtual XML-RPC servers' at the same port.
Make sure that the requestHandler accepts the paths in question.
"""
def __init__(self, addr, requestHandler=SimpleXMLRPCRequestHandler,
logRequests=True, allow_none=False, encoding=None, bind_and_activate=True):
SimpleXMLRPCServer.__init__(self, addr, requestHandler, logRequests, allow_none,
encoding, bind_and_activate)
self.dispatchers = {}
self.allow_none = allow_none
self.encoding = encoding
def add_dispatcher(self, path, dispatcher):
self.dispatchers[path] = dispatcher
return dispatcher
def get_dispatcher(self, path):
return self.dispatchers[path]
def _marshaled_dispatch(self, data, dispatch_method = None, path = None):
try:
response = self.dispatchers[path]._marshaled_dispatch(
data, dispatch_method, path)
except:
# report low level exception back to server
# (each dispatcher should have handled their own
# exceptions)
exc_type, exc_value = sys.exc_info()[:2]
response = xmlrpclib.dumps(
xmlrpclib.Fault(1, "%s:%s" % (exc_type, exc_value)),
encoding=self.encoding, allow_none=self.allow_none)
return response
class CGIXMLRPCRequestHandler(SimpleXMLRPCDispatcher):
"""Simple handler for XML-RPC data passed through CGI."""
def __init__(self, allow_none=False, encoding=None):
SimpleXMLRPCDispatcher.__init__(self, allow_none, encoding)
def handle_xmlrpc(self, request_text):
"""Handle a single XML-RPC request"""
response = self._marshaled_dispatch(request_text)
print 'Content-Type: text/xml'
print 'Content-Length: %d' % len(response)
print
sys.stdout.write(response)
def handle_get(self):
"""Handle a single HTTP GET request.
Default implementation indicates an error because
XML-RPC uses the POST method.
"""
code = 400
message, explain = \
BaseHTTPServer.BaseHTTPRequestHandler.responses[code]
response = BaseHTTPServer.DEFAULT_ERROR_MESSAGE % \
{
'code' : code,
'message' : message,
'explain' : explain
}
print 'Status: %d %s' % (code, message)
print 'Content-Type: %s' % BaseHTTPServer.DEFAULT_ERROR_CONTENT_TYPE
print 'Content-Length: %d' % len(response)
print
sys.stdout.write(response)
def handle_request(self, request_text = None):
"""Handle a single XML-RPC request passed through a CGI post method.
If no XML data is given then it is read from stdin. The resulting
XML-RPC response is printed to stdout along with the correct HTTP
headers.
"""
if request_text is None and \
os.environ.get('REQUEST_METHOD', None) == 'GET':
self.handle_get()
else:
# POST data is normally available through stdin
try:
length = int(os.environ.get('CONTENT_LENGTH', None))
except (TypeError, ValueError):
length = -1
if request_text is None:
request_text = sys.stdin.read(length)
self.handle_xmlrpc(request_text)
if __name__ == '__main__':
print 'Running XML-RPC server on port 8000'
server = SimpleXMLRPCServer(("localhost", 8000))
server.register_function(pow)
server.register_function(lambda x,y: x+y, 'add')
server.serve_forever()

View File

@@ -0,0 +1,716 @@
"""Generic socket server classes.
This module tries to capture the various aspects of defining a server:
For socket-based servers:
- address family:
- AF_INET{,6}: IP (Internet Protocol) sockets (default)
- AF_UNIX: Unix domain sockets
- others, e.g. AF_DECNET are conceivable (see <socket.h>
- socket type:
- SOCK_STREAM (reliable stream, e.g. TCP)
- SOCK_DGRAM (datagrams, e.g. UDP)
For request-based servers (including socket-based):
- client address verification before further looking at the request
(This is actually a hook for any processing that needs to look
at the request before anything else, e.g. logging)
- how to handle multiple requests:
- synchronous (one request is handled at a time)
- forking (each request is handled by a new process)
- threading (each request is handled by a new thread)
The classes in this module favor the server type that is simplest to
write: a synchronous TCP/IP server. This is bad class design, but
save some typing. (There's also the issue that a deep class hierarchy
slows down method lookups.)
There are five classes in an inheritance diagram, four of which represent
synchronous servers of four types:
+------------+
| BaseServer |
+------------+
|
v
+-----------+ +------------------+
| TCPServer |------->| UnixStreamServer |
+-----------+ +------------------+
|
v
+-----------+ +--------------------+
| UDPServer |------->| UnixDatagramServer |
+-----------+ +--------------------+
Note that UnixDatagramServer derives from UDPServer, not from
UnixStreamServer -- the only difference between an IP and a Unix
stream server is the address family, which is simply repeated in both
unix server classes.
Forking and threading versions of each type of server can be created
using the ForkingMixIn and ThreadingMixIn mix-in classes. For
instance, a threading UDP server class is created as follows:
class ThreadingUDPServer(ThreadingMixIn, UDPServer): pass
The Mix-in class must come first, since it overrides a method defined
in UDPServer! Setting the various member variables also changes
the behavior of the underlying server mechanism.
To implement a service, you must derive a class from
BaseRequestHandler and redefine its handle() method. You can then run
various versions of the service by combining one of the server classes
with your request handler class.
The request handler class must be different for datagram or stream
services. This can be hidden by using the request handler
subclasses StreamRequestHandler or DatagramRequestHandler.
Of course, you still have to use your head!
For instance, it makes no sense to use a forking server if the service
contains state in memory that can be modified by requests (since the
modifications in the child process would never reach the initial state
kept in the parent process and passed to each child). In this case,
you can use a threading server, but you will probably have to use
locks to avoid two requests that come in nearly simultaneous to apply
conflicting changes to the server state.
On the other hand, if you are building e.g. an HTTP server, where all
data is stored externally (e.g. in the file system), a synchronous
class will essentially render the service "deaf" while one request is
being handled -- which may be for a very long time if a client is slow
to reqd all the data it has requested. Here a threading or forking
server is appropriate.
In some cases, it may be appropriate to process part of a request
synchronously, but to finish processing in a forked child depending on
the request data. This can be implemented by using a synchronous
server and doing an explicit fork in the request handler class
handle() method.
Another approach to handling multiple simultaneous requests in an
environment that supports neither threads nor fork (or where these are
too expensive or inappropriate for the service) is to maintain an
explicit table of partially finished requests and to use select() to
decide which request to work on next (or whether to handle a new
incoming request). This is particularly important for stream services
where each client can potentially be connected for a long time (if
threads or subprocesses cannot be used).
Future work:
- Standard classes for Sun RPC (which uses either UDP or TCP)
- Standard mix-in classes to implement various authentication
and encryption schemes
- Standard framework for select-based multiplexing
XXX Open problems:
- What to do with out-of-band data?
BaseServer:
- split generic "request" functionality out into BaseServer class.
Copyright (C) 2000 Luke Kenneth Casson Leighton <lkcl@samba.org>
example: read entries from a SQL database (requires overriding
get_request() to return a table entry from the database).
entry is processed by a RequestHandlerClass.
"""
# Author of the BaseServer patch: Luke Kenneth Casson Leighton
# XXX Warning!
# There is a test suite for this module, but it cannot be run by the
# standard regression test.
# To run it manually, run Lib/test/test_socketserver.py.
__version__ = "0.4"
import socket
import select
import sys
import os
try:
import threading
except ImportError:
import dummy_threading as threading
__all__ = ["TCPServer","UDPServer","ForkingUDPServer","ForkingTCPServer",
"ThreadingUDPServer","ThreadingTCPServer","BaseRequestHandler",
"StreamRequestHandler","DatagramRequestHandler",
"ThreadingMixIn", "ForkingMixIn"]
if hasattr(socket, "AF_UNIX"):
__all__.extend(["UnixStreamServer","UnixDatagramServer",
"ThreadingUnixStreamServer",
"ThreadingUnixDatagramServer"])
class BaseServer:
"""Base class for server classes.
Methods for the caller:
- __init__(server_address, RequestHandlerClass)
- serve_forever(poll_interval=0.5)
- shutdown()
- handle_request() # if you do not use serve_forever()
- fileno() -> int # for select()
Methods that may be overridden:
- server_bind()
- server_activate()
- get_request() -> request, client_address
- handle_timeout()
- verify_request(request, client_address)
- server_close()
- process_request(request, client_address)
- shutdown_request(request)
- close_request(request)
- handle_error()
Methods for derived classes:
- finish_request(request, client_address)
Class variables that may be overridden by derived classes or
instances:
- timeout
- address_family
- socket_type
- allow_reuse_address
Instance variables:
- RequestHandlerClass
- socket
"""
timeout = None
def __init__(self, server_address, RequestHandlerClass):
"""Constructor. May be extended, do not override."""
self.server_address = server_address
self.RequestHandlerClass = RequestHandlerClass
self.__is_shut_down = threading.Event()
self.__shutdown_request = False
def server_activate(self):
"""Called by constructor to activate the server.
May be overridden.
"""
pass
def serve_forever(self, poll_interval=0.5):
"""Handle one request at a time until shutdown.
Polls for shutdown every poll_interval seconds. Ignores
self.timeout. If you need to do periodic tasks, do them in
another thread.
"""
self.__is_shut_down.clear()
try:
while not self.__shutdown_request:
# XXX: Consider using another file descriptor or
# connecting to the socket to wake this up instead of
# polling. Polling reduces our responsiveness to a
# shutdown request and wastes cpu at all other times.
r, w, e = select.select([self], [], [], poll_interval)
if self in r:
self._handle_request_noblock()
finally:
self.__shutdown_request = False
self.__is_shut_down.set()
def shutdown(self):
"""Stops the serve_forever loop.
Blocks until the loop has finished. This must be called while
serve_forever() is running in another thread, or it will
deadlock.
"""
self.__shutdown_request = True
self.__is_shut_down.wait()
# The distinction between handling, getting, processing and
# finishing a request is fairly arbitrary. Remember:
#
# - handle_request() is the top-level call. It calls
# select, get_request(), verify_request() and process_request()
# - get_request() is different for stream or datagram sockets
# - process_request() is the place that may fork a new process
# or create a new thread to finish the request
# - finish_request() instantiates the request handler class;
# this constructor will handle the request all by itself
def handle_request(self):
"""Handle one request, possibly blocking.
Respects self.timeout.
"""
# Support people who used socket.settimeout() to escape
# handle_request before self.timeout was available.
timeout = self.socket.gettimeout()
if timeout is None:
timeout = self.timeout
elif self.timeout is not None:
timeout = min(timeout, self.timeout)
fd_sets = select.select([self], [], [], timeout)
if not fd_sets[0]:
self.handle_timeout()
return
self._handle_request_noblock()
def _handle_request_noblock(self):
"""Handle one request, without blocking.
I assume that select.select has returned that the socket is
readable before this function was called, so there should be
no risk of blocking in get_request().
"""
try:
request, client_address = self.get_request()
except socket.error:
return
if self.verify_request(request, client_address):
try:
self.process_request(request, client_address)
except:
self.handle_error(request, client_address)
self.shutdown_request(request)
def handle_timeout(self):
"""Called if no new request arrives within self.timeout.
Overridden by ForkingMixIn.
"""
pass
def verify_request(self, request, client_address):
"""Verify the request. May be overridden.
Return True if we should proceed with this request.
"""
return True
def process_request(self, request, client_address):
"""Call finish_request.
Overridden by ForkingMixIn and ThreadingMixIn.
"""
self.finish_request(request, client_address)
self.shutdown_request(request)
def server_close(self):
"""Called to clean-up the server.
May be overridden.
"""
pass
def finish_request(self, request, client_address):
"""Finish one request by instantiating RequestHandlerClass."""
self.RequestHandlerClass(request, client_address, self)
def shutdown_request(self, request):
"""Called to shutdown and close an individual request."""
self.close_request(request)
def close_request(self, request):
"""Called to clean up an individual request."""
pass
def handle_error(self, request, client_address):
"""Handle an error gracefully. May be overridden.
The default is to print a traceback and continue.
"""
print '-'*40
print 'Exception happened during processing of request from',
print client_address
import traceback
traceback.print_exc() # XXX But this goes to stderr!
print '-'*40
class TCPServer(BaseServer):
"""Base class for various socket-based server classes.
Defaults to synchronous IP stream (i.e., TCP).
Methods for the caller:
- __init__(server_address, RequestHandlerClass, bind_and_activate=True)
- serve_forever(poll_interval=0.5)
- shutdown()
- handle_request() # if you don't use serve_forever()
- fileno() -> int # for select()
Methods that may be overridden:
- server_bind()
- server_activate()
- get_request() -> request, client_address
- handle_timeout()
- verify_request(request, client_address)
- process_request(request, client_address)
- shutdown_request(request)
- close_request(request)
- handle_error()
Methods for derived classes:
- finish_request(request, client_address)
Class variables that may be overridden by derived classes or
instances:
- timeout
- address_family
- socket_type
- request_queue_size (only for stream sockets)
- allow_reuse_address
Instance variables:
- server_address
- RequestHandlerClass
- socket
"""
address_family = socket.AF_INET
socket_type = socket.SOCK_STREAM
request_queue_size = 5
allow_reuse_address = False
def __init__(self, server_address, RequestHandlerClass, bind_and_activate=True):
"""Constructor. May be extended, do not override."""
BaseServer.__init__(self, server_address, RequestHandlerClass)
self.socket = socket.socket(self.address_family,
self.socket_type)
if bind_and_activate:
self.server_bind()
self.server_activate()
def server_bind(self):
"""Called by constructor to bind the socket.
May be overridden.
"""
if self.allow_reuse_address:
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.bind(self.server_address)
self.server_address = self.socket.getsockname()
def server_activate(self):
"""Called by constructor to activate the server.
May be overridden.
"""
self.socket.listen(self.request_queue_size)
def server_close(self):
"""Called to clean-up the server.
May be overridden.
"""
self.socket.close()
def fileno(self):
"""Return socket file number.
Interface required by select().
"""
return self.socket.fileno()
def get_request(self):
"""Get the request and client address from the socket.
May be overridden.
"""
return self.socket.accept()
def shutdown_request(self, request):
"""Called to shutdown and close an individual request."""
try:
#explicitly shutdown. socket.close() merely releases
#the socket and waits for GC to perform the actual close.
request.shutdown(socket.SHUT_WR)
except socket.error:
pass #some platforms may raise ENOTCONN here
self.close_request(request)
def close_request(self, request):
"""Called to clean up an individual request."""
request.close()
class UDPServer(TCPServer):
"""UDP server class."""
allow_reuse_address = False
socket_type = socket.SOCK_DGRAM
max_packet_size = 8192
def get_request(self):
data, client_addr = self.socket.recvfrom(self.max_packet_size)
return (data, self.socket), client_addr
def server_activate(self):
# No need to call listen() for UDP.
pass
def shutdown_request(self, request):
# No need to shutdown anything.
self.close_request(request)
def close_request(self, request):
# No need to close anything.
pass
class ForkingMixIn:
"""Mix-in class to handle each request in a new process."""
timeout = 300
active_children = None
max_children = 40
def collect_children(self):
"""Internal routine to wait for children that have exited."""
if self.active_children is None: return
while len(self.active_children) >= self.max_children:
# XXX: This will wait for any child process, not just ones
# spawned by this library. This could confuse other
# libraries that expect to be able to wait for their own
# children.
try:
pid, status = os.waitpid(0, 0)
except os.error:
pid = None
if pid not in self.active_children: continue
self.active_children.remove(pid)
# XXX: This loop runs more system calls than it ought
# to. There should be a way to put the active_children into a
# process group and then use os.waitpid(-pgid) to wait for any
# of that set, but I couldn't find a way to allocate pgids
# that couldn't collide.
for child in self.active_children:
try:
pid, status = os.waitpid(child, os.WNOHANG)
except os.error:
pid = None
if not pid: continue
try:
self.active_children.remove(pid)
except ValueError, e:
raise ValueError('%s. x=%d and list=%r' % (e.message, pid,
self.active_children))
def handle_timeout(self):
"""Wait for zombies after self.timeout seconds of inactivity.
May be extended, do not override.
"""
self.collect_children()
def process_request(self, request, client_address):
"""Fork a new subprocess to process the request."""
self.collect_children()
pid = os.fork()
if pid:
# Parent process
if self.active_children is None:
self.active_children = []
self.active_children.append(pid)
self.close_request(request) #close handle in parent process
return
else:
# Child process.
# This must never return, hence os._exit()!
try:
self.finish_request(request, client_address)
self.shutdown_request(request)
os._exit(0)
except:
try:
self.handle_error(request, client_address)
self.shutdown_request(request)
finally:
os._exit(1)
class ThreadingMixIn:
"""Mix-in class to handle each request in a new thread."""
# Decides how threads will act upon termination of the
# main process
daemon_threads = False
def process_request_thread(self, request, client_address):
"""Same as in BaseServer but as a thread.
In addition, exception handling is done here.
"""
try:
self.finish_request(request, client_address)
self.shutdown_request(request)
except:
self.handle_error(request, client_address)
self.shutdown_request(request)
def process_request(self, request, client_address):
"""Start a new thread to process the request."""
t = threading.Thread(target = self.process_request_thread,
args = (request, client_address))
if self.daemon_threads:
t.setDaemon (1)
t.start()
class ForkingUDPServer(ForkingMixIn, UDPServer): pass
class ForkingTCPServer(ForkingMixIn, TCPServer): pass
class ThreadingUDPServer(ThreadingMixIn, UDPServer): pass
class ThreadingTCPServer(ThreadingMixIn, TCPServer): pass
if hasattr(socket, 'AF_UNIX'):
class UnixStreamServer(TCPServer):
address_family = socket.AF_UNIX
class UnixDatagramServer(UDPServer):
address_family = socket.AF_UNIX
class ThreadingUnixStreamServer(ThreadingMixIn, UnixStreamServer): pass
class ThreadingUnixDatagramServer(ThreadingMixIn, UnixDatagramServer): pass
class BaseRequestHandler:
"""Base class for request handler classes.
This class is instantiated for each request to be handled. The
constructor sets the instance variables request, client_address
and server, and then calls the handle() method. To implement a
specific service, all you need to do is to derive a class which
defines a handle() method.
The handle() method can find the request as self.request, the
client address as self.client_address, and the server (in case it
needs access to per-server information) as self.server. Since a
separate instance is created for each request, the handle() method
can define arbitrary other instance variariables.
"""
def __init__(self, request, client_address, server):
self.request = request
self.client_address = client_address
self.server = server
self.setup()
try:
self.handle()
finally:
self.finish()
def setup(self):
pass
def handle(self):
pass
def finish(self):
pass
# The following two classes make it possible to use the same service
# class for stream or datagram servers.
# Each class sets up these instance variables:
# - rfile: a file object from which receives the request is read
# - wfile: a file object to which the reply is written
# When the handle() method returns, wfile is flushed properly
class StreamRequestHandler(BaseRequestHandler):
"""Define self.rfile and self.wfile for stream sockets."""
# Default buffer sizes for rfile, wfile.
# We default rfile to buffered because otherwise it could be
# really slow for large data (a getc() call per byte); we make
# wfile unbuffered because (a) often after a write() we want to
# read and we need to flush the line; (b) big writes to unbuffered
# files are typically optimized by stdio even when big reads
# aren't.
rbufsize = -1
wbufsize = 0
# A timeout to apply to the request socket, if not None.
timeout = None
# Disable nagle algoritm for this socket, if True.
# Use only when wbufsize != 0, to avoid small packets.
disable_nagle_algorithm = False
def setup(self):
self.connection = self.request
if self.timeout is not None:
self.connection.settimeout(self.timeout)
if self.disable_nagle_algorithm:
self.connection.setsockopt(socket.IPPROTO_TCP,
socket.TCP_NODELAY, True)
self.rfile = self.connection.makefile('rb', self.rbufsize)
self.wfile = self.connection.makefile('wb', self.wbufsize)
def finish(self):
if not self.wfile.closed:
self.wfile.flush()
self.wfile.close()
self.rfile.close()
class DatagramRequestHandler(BaseRequestHandler):
# XXX Regrettably, I cannot get this working on Linux;
# s.recvfrom() doesn't return a meaningful client address.
"""Define self.rfile and self.wfile for datagram sockets."""
def setup(self):
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO
self.packet, self.socket = self.request
self.rfile = StringIO(self.packet)
self.wfile = StringIO()
def finish(self):
self.socket.sendto(self.wfile.getvalue(), self.client_address)

View File

@@ -0,0 +1,323 @@
r"""File-like objects that read from or write to a string buffer.
This implements (nearly) all stdio methods.
f = StringIO() # ready for writing
f = StringIO(buf) # ready for reading
f.close() # explicitly release resources held
flag = f.isatty() # always false
pos = f.tell() # get current position
f.seek(pos) # set current position
f.seek(pos, mode) # mode 0: absolute; 1: relative; 2: relative to EOF
buf = f.read() # read until EOF
buf = f.read(n) # read up to n bytes
buf = f.readline() # read until end of line ('\n') or EOF
list = f.readlines()# list of f.readline() results until EOF
f.truncate([size]) # truncate file at to at most size (default: current pos)
f.write(buf) # write at current position
f.writelines(list) # for line in list: f.write(line)
f.getvalue() # return whole file's contents as a string
Notes:
- Using a real file is often faster (but less convenient).
- There's also a much faster implementation in C, called cStringIO, but
it's not subclassable.
- fileno() is left unimplemented so that code which uses it triggers
an exception early.
- Seeking far beyond EOF and then writing will insert real null
bytes that occupy space in the buffer.
- There's a simple test set (see end of this file).
"""
try:
from errno import EINVAL
except ImportError:
EINVAL = 22
__all__ = ["StringIO"]
def _complain_ifclosed(closed):
if closed:
raise ValueError, "I/O operation on closed file"
class StringIO:
"""class StringIO([buffer])
When a StringIO object is created, it can be initialized to an existing
string by passing the string to the constructor. If no string is given,
the StringIO will start empty.
The StringIO object can accept either Unicode or 8-bit strings, but
mixing the two may take some care. If both are used, 8-bit strings that
cannot be interpreted as 7-bit ASCII (that use the 8th bit) will cause
a UnicodeError to be raised when getvalue() is called.
"""
def __init__(self, buf = ''):
# Force self.buf to be a string or unicode
if not isinstance(buf, basestring):
buf = str(buf)
self.buf = buf
self.len = len(buf)
self.buflist = []
self.pos = 0
self.closed = False
self.softspace = 0
def __iter__(self):
return self
def next(self):
"""A file object is its own iterator, for example iter(f) returns f
(unless f is closed). When a file is used as an iterator, typically
in a for loop (for example, for line in f: print line), the next()
method is called repeatedly. This method returns the next input line,
or raises StopIteration when EOF is hit.
"""
_complain_ifclosed(self.closed)
r = self.readline()
if not r:
raise StopIteration
return r
def close(self):
"""Free the memory buffer.
"""
if not self.closed:
self.closed = True
del self.buf, self.pos
def isatty(self):
"""Returns False because StringIO objects are not connected to a
tty-like device.
"""
_complain_ifclosed(self.closed)
return False
def seek(self, pos, mode = 0):
"""Set the file's current position.
The mode argument is optional and defaults to 0 (absolute file
positioning); other values are 1 (seek relative to the current
position) and 2 (seek relative to the file's end).
There is no return value.
"""
_complain_ifclosed(self.closed)
if self.buflist:
self.buf += ''.join(self.buflist)
self.buflist = []
if mode == 1:
pos += self.pos
elif mode == 2:
pos += self.len
self.pos = max(0, pos)
def tell(self):
"""Return the file's current position."""
_complain_ifclosed(self.closed)
return self.pos
def read(self, n = -1):
"""Read at most size bytes from the file
(less if the read hits EOF before obtaining size bytes).
If the size argument is negative or omitted, read all data until EOF
is reached. The bytes are returned as a string object. An empty
string is returned when EOF is encountered immediately.
"""
_complain_ifclosed(self.closed)
if self.buflist:
self.buf += ''.join(self.buflist)
self.buflist = []
if n is None or n < 0:
newpos = self.len
else:
newpos = min(self.pos+n, self.len)
r = self.buf[self.pos:newpos]
self.pos = newpos
return r
def readline(self, length=None):
r"""Read one entire line from the file.
A trailing newline character is kept in the string (but may be absent
when a file ends with an incomplete line). If the size argument is
present and non-negative, it is a maximum byte count (including the
trailing newline) and an incomplete line may be returned.
An empty string is returned only when EOF is encountered immediately.
Note: Unlike stdio's fgets(), the returned string contains null
characters ('\0') if they occurred in the input.
"""
_complain_ifclosed(self.closed)
if self.buflist:
self.buf += ''.join(self.buflist)
self.buflist = []
i = self.buf.find('\n', self.pos)
if i < 0:
newpos = self.len
else:
newpos = i+1
if length is not None and length > 0:
if self.pos + length < newpos:
newpos = self.pos + length
r = self.buf[self.pos:newpos]
self.pos = newpos
return r
def readlines(self, sizehint = 0):
"""Read until EOF using readline() and return a list containing the
lines thus read.
If the optional sizehint argument is present, instead of reading up
to EOF, whole lines totalling approximately sizehint bytes (or more
to accommodate a final whole line).
"""
total = 0
lines = []
line = self.readline()
while line:
lines.append(line)
total += len(line)
if 0 < sizehint <= total:
break
line = self.readline()
return lines
def truncate(self, size=None):
"""Truncate the file's size.
If the optional size argument is present, the file is truncated to
(at most) that size. The size defaults to the current position.
The current file position is not changed unless the position
is beyond the new file size.
If the specified size exceeds the file's current size, the
file remains unchanged.
"""
_complain_ifclosed(self.closed)
if size is None:
size = self.pos
elif size < 0:
raise IOError(EINVAL, "Negative size not allowed")
elif size < self.pos:
self.pos = size
self.buf = self.getvalue()[:size]
self.len = size
def write(self, s):
"""Write a string to the file.
There is no return value.
"""
_complain_ifclosed(self.closed)
if not s: return
# Force s to be a string or unicode
if not isinstance(s, basestring):
s = str(s)
spos = self.pos
slen = self.len
if spos == slen:
self.buflist.append(s)
self.len = self.pos = spos + len(s)
return
if spos > slen:
self.buflist.append('\0'*(spos - slen))
slen = spos
newpos = spos + len(s)
if spos < slen:
if self.buflist:
self.buf += ''.join(self.buflist)
self.buflist = [self.buf[:spos], s, self.buf[newpos:]]
self.buf = ''
if newpos > slen:
slen = newpos
else:
self.buflist.append(s)
slen = newpos
self.len = slen
self.pos = newpos
def writelines(self, iterable):
"""Write a sequence of strings to the file. The sequence can be any
iterable object producing strings, typically a list of strings. There
is no return value.
(The name is intended to match readlines(); writelines() does not add
line separators.)
"""
write = self.write
for line in iterable:
write(line)
def flush(self):
"""Flush the internal buffer
"""
_complain_ifclosed(self.closed)
def getvalue(self):
"""
Retrieve the entire contents of the "file" at any time before
the StringIO object's close() method is called.
The StringIO object can accept either Unicode or 8-bit strings,
but mixing the two may take some care. If both are used, 8-bit
strings that cannot be interpreted as 7-bit ASCII (that use the
8th bit) will cause a UnicodeError to be raised when getvalue()
is called.
"""
if self.buflist:
self.buf += ''.join(self.buflist)
self.buflist = []
return self.buf
# A little test suite
def test():
import sys
if sys.argv[1:]:
file = sys.argv[1]
else:
file = '/etc/passwd'
lines = open(file, 'r').readlines()
text = open(file, 'r').read()
f = StringIO()
for line in lines[:-2]:
f.write(line)
f.writelines(lines[-2:])
if f.getvalue() != text:
raise RuntimeError, 'write failed'
length = f.tell()
print 'File length =', length
f.seek(len(lines[0]))
f.write(lines[1])
f.seek(0)
print 'First line =', repr(f.readline())
print 'Position =', f.tell()
line = f.readline()
print 'Second line =', repr(line)
f.seek(-len(line), 1)
line2 = f.read(len(line))
if line != line2:
raise RuntimeError, 'bad result after seek back'
f.seek(len(line2), 1)
list = f.readlines()
line = list[-1]
f.seek(f.tell() - len(line))
line2 = f.read()
if line != line2:
raise RuntimeError, 'bad result after seek back from EOF'
print 'Read', len(list), 'more lines'
print 'File length =', f.tell()
if f.tell() != length:
raise RuntimeError, 'bad length'
f.truncate(length/2)
f.seek(0, 2)
print 'Truncated length =', f.tell()
if f.tell() != length/2:
raise RuntimeError, 'truncate did not adjust length'
f.close()
if __name__ == '__main__':
test()

View File

@@ -0,0 +1,180 @@
"""A more or less complete user-defined wrapper around dictionary objects."""
class UserDict:
def __init__(self, dict=None, **kwargs):
self.data = {}
if dict is not None:
self.update(dict)
if len(kwargs):
self.update(kwargs)
def __repr__(self): return repr(self.data)
def __cmp__(self, dict):
if isinstance(dict, UserDict):
return cmp(self.data, dict.data)
else:
return cmp(self.data, dict)
__hash__ = None # Avoid Py3k warning
def __len__(self): return len(self.data)
def __getitem__(self, key):
if key in self.data:
return self.data[key]
if hasattr(self.__class__, "__missing__"):
return self.__class__.__missing__(self, key)
raise KeyError(key)
def __setitem__(self, key, item): self.data[key] = item
def __delitem__(self, key): del self.data[key]
def clear(self): self.data.clear()
def copy(self):
if self.__class__ is UserDict:
return UserDict(self.data.copy())
import copy
data = self.data
try:
self.data = {}
c = copy.copy(self)
finally:
self.data = data
c.update(self)
return c
def keys(self): return self.data.keys()
def items(self): return self.data.items()
def iteritems(self): return self.data.iteritems()
def iterkeys(self): return self.data.iterkeys()
def itervalues(self): return self.data.itervalues()
def values(self): return self.data.values()
def has_key(self, key): return key in self.data
def update(self, dict=None, **kwargs):
if dict is None:
pass
elif isinstance(dict, UserDict):
self.data.update(dict.data)
elif isinstance(dict, type({})) or not hasattr(dict, 'items'):
self.data.update(dict)
else:
for k, v in dict.items():
self[k] = v
if len(kwargs):
self.data.update(kwargs)
def get(self, key, failobj=None):
if key not in self:
return failobj
return self[key]
def setdefault(self, key, failobj=None):
if key not in self:
self[key] = failobj
return self[key]
def pop(self, key, *args):
return self.data.pop(key, *args)
def popitem(self):
return self.data.popitem()
def __contains__(self, key):
return key in self.data
@classmethod
def fromkeys(cls, iterable, value=None):
d = cls()
for key in iterable:
d[key] = value
return d
class IterableUserDict(UserDict):
def __iter__(self):
return iter(self.data)
import _abcoll
_abcoll.MutableMapping.register(IterableUserDict)
class DictMixin:
# Mixin defining all dictionary methods for classes that already have
# a minimum dictionary interface including getitem, setitem, delitem,
# and keys. Without knowledge of the subclass constructor, the mixin
# does not define __init__() or copy(). In addition to the four base
# methods, progressively more efficiency comes with defining
# __contains__(), __iter__(), and iteritems().
# second level definitions support higher levels
def __iter__(self):
for k in self.keys():
yield k
def has_key(self, key):
try:
self[key]
except KeyError:
return False
return True
def __contains__(self, key):
return self.has_key(key)
# third level takes advantage of second level definitions
def iteritems(self):
for k in self:
yield (k, self[k])
def iterkeys(self):
return self.__iter__()
# fourth level uses definitions from lower levels
def itervalues(self):
for _, v in self.iteritems():
yield v
def values(self):
return [v for _, v in self.iteritems()]
def items(self):
return list(self.iteritems())
def clear(self):
for key in self.keys():
del self[key]
def setdefault(self, key, default=None):
try:
return self[key]
except KeyError:
self[key] = default
return default
def pop(self, key, *args):
if len(args) > 1:
raise TypeError, "pop expected at most 2 arguments, got "\
+ repr(1 + len(args))
try:
value = self[key]
except KeyError:
if args:
return args[0]
raise
del self[key]
return value
def popitem(self):
try:
k, v = self.iteritems().next()
except StopIteration:
raise KeyError, 'container is empty'
del self[k]
return (k, v)
def update(self, other=None, **kwargs):
# Make progressively weaker assumptions about "other"
if other is None:
pass
elif hasattr(other, 'iteritems'): # iteritems saves memory and lookups
for k, v in other.iteritems():
self[k] = v
elif hasattr(other, 'keys'):
for k in other.keys():
self[k] = other[k]
else:
for k, v in other:
self[k] = v
if kwargs:
self.update(kwargs)
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def __repr__(self):
return repr(dict(self.iteritems()))
def __cmp__(self, other):
if other is None:
return 1
if isinstance(other, DictMixin):
other = dict(other.iteritems())
return cmp(dict(self.iteritems()), other)
def __len__(self):
return len(self.keys())

View File

@@ -0,0 +1,88 @@
"""A more or less complete user-defined wrapper around list objects."""
import collections
class UserList(collections.MutableSequence):
def __init__(self, initlist=None):
self.data = []
if initlist is not None:
# XXX should this accept an arbitrary sequence?
if type(initlist) == type(self.data):
self.data[:] = initlist
elif isinstance(initlist, UserList):
self.data[:] = initlist.data[:]
else:
self.data = list(initlist)
def __repr__(self): return repr(self.data)
def __lt__(self, other): return self.data < self.__cast(other)
def __le__(self, other): return self.data <= self.__cast(other)
def __eq__(self, other): return self.data == self.__cast(other)
def __ne__(self, other): return self.data != self.__cast(other)
def __gt__(self, other): return self.data > self.__cast(other)
def __ge__(self, other): return self.data >= self.__cast(other)
def __cast(self, other):
if isinstance(other, UserList): return other.data
else: return other
def __cmp__(self, other):
return cmp(self.data, self.__cast(other))
__hash__ = None # Mutable sequence, so not hashable
def __contains__(self, item): return item in self.data
def __len__(self): return len(self.data)
def __getitem__(self, i): return self.data[i]
def __setitem__(self, i, item): self.data[i] = item
def __delitem__(self, i): del self.data[i]
def __getslice__(self, i, j):
i = max(i, 0); j = max(j, 0)
return self.__class__(self.data[i:j])
def __setslice__(self, i, j, other):
i = max(i, 0); j = max(j, 0)
if isinstance(other, UserList):
self.data[i:j] = other.data
elif isinstance(other, type(self.data)):
self.data[i:j] = other
else:
self.data[i:j] = list(other)
def __delslice__(self, i, j):
i = max(i, 0); j = max(j, 0)
del self.data[i:j]
def __add__(self, other):
if isinstance(other, UserList):
return self.__class__(self.data + other.data)
elif isinstance(other, type(self.data)):
return self.__class__(self.data + other)
else:
return self.__class__(self.data + list(other))
def __radd__(self, other):
if isinstance(other, UserList):
return self.__class__(other.data + self.data)
elif isinstance(other, type(self.data)):
return self.__class__(other + self.data)
else:
return self.__class__(list(other) + self.data)
def __iadd__(self, other):
if isinstance(other, UserList):
self.data += other.data
elif isinstance(other, type(self.data)):
self.data += other
else:
self.data += list(other)
return self
def __mul__(self, n):
return self.__class__(self.data*n)
__rmul__ = __mul__
def __imul__(self, n):
self.data *= n
return self
def append(self, item): self.data.append(item)
def insert(self, i, item): self.data.insert(i, item)
def pop(self, i=-1): return self.data.pop(i)
def remove(self, item): self.data.remove(item)
def count(self, item): return self.data.count(item)
def index(self, item, *args): return self.data.index(item, *args)
def reverse(self): self.data.reverse()
def sort(self, *args, **kwds): self.data.sort(*args, **kwds)
def extend(self, other):
if isinstance(other, UserList):
self.data.extend(other.data)
else:
self.data.extend(other)

228
src/python/stdlib/UserString.py Executable file
View File

@@ -0,0 +1,228 @@
#!/usr/bin/env python
## vim:ts=4:et:nowrap
"""A user-defined wrapper around string objects
Note: string objects have grown methods in Python 1.6
This module requires Python 1.6 or later.
"""
import sys
import collections
__all__ = ["UserString","MutableString"]
class UserString(collections.Sequence):
def __init__(self, seq):
if isinstance(seq, basestring):
self.data = seq
elif isinstance(seq, UserString):
self.data = seq.data[:]
else:
self.data = str(seq)
def __str__(self): return str(self.data)
def __repr__(self): return repr(self.data)
def __int__(self): return int(self.data)
def __long__(self): return long(self.data)
def __float__(self): return float(self.data)
def __complex__(self): return complex(self.data)
def __hash__(self): return hash(self.data)
def __cmp__(self, string):
if isinstance(string, UserString):
return cmp(self.data, string.data)
else:
return cmp(self.data, string)
def __contains__(self, char):
return char in self.data
def __len__(self): return len(self.data)
def __getitem__(self, index): return self.__class__(self.data[index])
def __getslice__(self, start, end):
start = max(start, 0); end = max(end, 0)
return self.__class__(self.data[start:end])
def __add__(self, other):
if isinstance(other, UserString):
return self.__class__(self.data + other.data)
elif isinstance(other, basestring):
return self.__class__(self.data + other)
else:
return self.__class__(self.data + str(other))
def __radd__(self, other):
if isinstance(other, basestring):
return self.__class__(other + self.data)
else:
return self.__class__(str(other) + self.data)
def __mul__(self, n):
return self.__class__(self.data*n)
__rmul__ = __mul__
def __mod__(self, args):
return self.__class__(self.data % args)
# the following methods are defined in alphabetical order:
def capitalize(self): return self.__class__(self.data.capitalize())
def center(self, width, *args):
return self.__class__(self.data.center(width, *args))
def count(self, sub, start=0, end=sys.maxint):
return self.data.count(sub, start, end)
def decode(self, encoding=None, errors=None): # XXX improve this?
if encoding:
if errors:
return self.__class__(self.data.decode(encoding, errors))
else:
return self.__class__(self.data.decode(encoding))
else:
return self.__class__(self.data.decode())
def encode(self, encoding=None, errors=None): # XXX improve this?
if encoding:
if errors:
return self.__class__(self.data.encode(encoding, errors))
else:
return self.__class__(self.data.encode(encoding))
else:
return self.__class__(self.data.encode())
def endswith(self, suffix, start=0, end=sys.maxint):
return self.data.endswith(suffix, start, end)
def expandtabs(self, tabsize=8):
return self.__class__(self.data.expandtabs(tabsize))
def find(self, sub, start=0, end=sys.maxint):
return self.data.find(sub, start, end)
def index(self, sub, start=0, end=sys.maxint):
return self.data.index(sub, start, end)
def isalpha(self): return self.data.isalpha()
def isalnum(self): return self.data.isalnum()
def isdecimal(self): return self.data.isdecimal()
def isdigit(self): return self.data.isdigit()
def islower(self): return self.data.islower()
def isnumeric(self): return self.data.isnumeric()
def isspace(self): return self.data.isspace()
def istitle(self): return self.data.istitle()
def isupper(self): return self.data.isupper()
def join(self, seq): return self.data.join(seq)
def ljust(self, width, *args):
return self.__class__(self.data.ljust(width, *args))
def lower(self): return self.__class__(self.data.lower())
def lstrip(self, chars=None): return self.__class__(self.data.lstrip(chars))
def partition(self, sep):
return self.data.partition(sep)
def replace(self, old, new, maxsplit=-1):
return self.__class__(self.data.replace(old, new, maxsplit))
def rfind(self, sub, start=0, end=sys.maxint):
return self.data.rfind(sub, start, end)
def rindex(self, sub, start=0, end=sys.maxint):
return self.data.rindex(sub, start, end)
def rjust(self, width, *args):
return self.__class__(self.data.rjust(width, *args))
def rpartition(self, sep):
return self.data.rpartition(sep)
def rstrip(self, chars=None): return self.__class__(self.data.rstrip(chars))
def split(self, sep=None, maxsplit=-1):
return self.data.split(sep, maxsplit)
def rsplit(self, sep=None, maxsplit=-1):
return self.data.rsplit(sep, maxsplit)
def splitlines(self, keepends=0): return self.data.splitlines(keepends)
def startswith(self, prefix, start=0, end=sys.maxint):
return self.data.startswith(prefix, start, end)
def strip(self, chars=None): return self.__class__(self.data.strip(chars))
def swapcase(self): return self.__class__(self.data.swapcase())
def title(self): return self.__class__(self.data.title())
def translate(self, *args):
return self.__class__(self.data.translate(*args))
def upper(self): return self.__class__(self.data.upper())
def zfill(self, width): return self.__class__(self.data.zfill(width))
class MutableString(UserString, collections.MutableSequence):
"""mutable string objects
Python strings are immutable objects. This has the advantage, that
strings may be used as dictionary keys. If this property isn't needed
and you insist on changing string values in place instead, you may cheat
and use MutableString.
But the purpose of this class is an educational one: to prevent
people from inventing their own mutable string class derived
from UserString and than forget thereby to remove (override) the
__hash__ method inherited from UserString. This would lead to
errors that would be very hard to track down.
A faster and better solution is to rewrite your program using lists."""
def __init__(self, string=""):
from warnings import warnpy3k
warnpy3k('the class UserString.MutableString has been removed in '
'Python 3.0', stacklevel=2)
self.data = string
# We inherit object.__hash__, so we must deny this explicitly
__hash__ = None
def __setitem__(self, index, sub):
if isinstance(index, slice):
if isinstance(sub, UserString):
sub = sub.data
elif not isinstance(sub, basestring):
sub = str(sub)
start, stop, step = index.indices(len(self.data))
if step == -1:
start, stop = stop+1, start+1
sub = sub[::-1]
elif step != 1:
# XXX(twouters): I guess we should be reimplementing
# the extended slice assignment/deletion algorithm here...
raise TypeError, "invalid step in slicing assignment"
start = min(start, stop)
self.data = self.data[:start] + sub + self.data[stop:]
else:
if index < 0:
index += len(self.data)
if index < 0 or index >= len(self.data): raise IndexError
self.data = self.data[:index] + sub + self.data[index+1:]
def __delitem__(self, index):
if isinstance(index, slice):
start, stop, step = index.indices(len(self.data))
if step == -1:
start, stop = stop+1, start+1
elif step != 1:
# XXX(twouters): see same block in __setitem__
raise TypeError, "invalid step in slicing deletion"
start = min(start, stop)
self.data = self.data[:start] + self.data[stop:]
else:
if index < 0:
index += len(self.data)
if index < 0 or index >= len(self.data): raise IndexError
self.data = self.data[:index] + self.data[index+1:]
def __setslice__(self, start, end, sub):
start = max(start, 0); end = max(end, 0)
if isinstance(sub, UserString):
self.data = self.data[:start]+sub.data+self.data[end:]
elif isinstance(sub, basestring):
self.data = self.data[:start]+sub+self.data[end:]
else:
self.data = self.data[:start]+str(sub)+self.data[end:]
def __delslice__(self, start, end):
start = max(start, 0); end = max(end, 0)
self.data = self.data[:start] + self.data[end:]
def immutable(self):
return UserString(self.data)
def __iadd__(self, other):
if isinstance(other, UserString):
self.data += other.data
elif isinstance(other, basestring):
self.data += other
else:
self.data += str(other)
return self
def __imul__(self, n):
self.data *= n
return self
def insert(self, index, value):
self[index:index] = value
if __name__ == "__main__":
# execute the regression test to stdout, if called as a script:
import os
called_in_dir, called_as = os.path.split(sys.argv[0])
called_as, py = os.path.splitext(called_as)
if '-q' in sys.argv:
from test import test_support
test_support.verbose = 0
__import__('test.test_' + called_as.lower())

View File

@@ -0,0 +1,170 @@
"""Load / save to libwww-perl (LWP) format files.
Actually, the format is slightly extended from that used by LWP's
(libwww-perl's) HTTP::Cookies, to avoid losing some RFC 2965 information
not recorded by LWP.
It uses the version string "2.0", though really there isn't an LWP Cookies
2.0 format. This indicates that there is extra information in here
(domain_dot and # port_spec) while still being compatible with
libwww-perl, I hope.
"""
import time, re
from cookielib import (_warn_unhandled_exception, FileCookieJar, LoadError,
Cookie, MISSING_FILENAME_TEXT,
join_header_words, split_header_words,
iso2time, time2isoz)
def lwp_cookie_str(cookie):
"""Return string representation of Cookie in an the LWP cookie file format.
Actually, the format is extended a bit -- see module docstring.
"""
h = [(cookie.name, cookie.value),
("path", cookie.path),
("domain", cookie.domain)]
if cookie.port is not None: h.append(("port", cookie.port))
if cookie.path_specified: h.append(("path_spec", None))
if cookie.port_specified: h.append(("port_spec", None))
if cookie.domain_initial_dot: h.append(("domain_dot", None))
if cookie.secure: h.append(("secure", None))
if cookie.expires: h.append(("expires",
time2isoz(float(cookie.expires))))
if cookie.discard: h.append(("discard", None))
if cookie.comment: h.append(("comment", cookie.comment))
if cookie.comment_url: h.append(("commenturl", cookie.comment_url))
keys = cookie._rest.keys()
keys.sort()
for k in keys:
h.append((k, str(cookie._rest[k])))
h.append(("version", str(cookie.version)))
return join_header_words([h])
class LWPCookieJar(FileCookieJar):
"""
The LWPCookieJar saves a sequence of"Set-Cookie3" lines.
"Set-Cookie3" is the format used by the libwww-perl libary, not known
to be compatible with any browser, but which is easy to read and
doesn't lose information about RFC 2965 cookies.
Additional methods
as_lwp_str(ignore_discard=True, ignore_expired=True)
"""
def as_lwp_str(self, ignore_discard=True, ignore_expires=True):
"""Return cookies as a string of "\n"-separated "Set-Cookie3" headers.
ignore_discard and ignore_expires: see docstring for FileCookieJar.save
"""
now = time.time()
r = []
for cookie in self:
if not ignore_discard and cookie.discard:
continue
if not ignore_expires and cookie.is_expired(now):
continue
r.append("Set-Cookie3: %s" % lwp_cookie_str(cookie))
return "\n".join(r+[""])
def save(self, filename=None, ignore_discard=False, ignore_expires=False):
if filename is None:
if self.filename is not None: filename = self.filename
else: raise ValueError(MISSING_FILENAME_TEXT)
f = open(filename, "w")
try:
# There really isn't an LWP Cookies 2.0 format, but this indicates
# that there is extra information in here (domain_dot and
# port_spec) while still being compatible with libwww-perl, I hope.
f.write("#LWP-Cookies-2.0\n")
f.write(self.as_lwp_str(ignore_discard, ignore_expires))
finally:
f.close()
def _really_load(self, f, filename, ignore_discard, ignore_expires):
magic = f.readline()
if not re.search(self.magic_re, magic):
msg = ("%r does not look like a Set-Cookie3 (LWP) format "
"file" % filename)
raise LoadError(msg)
now = time.time()
header = "Set-Cookie3:"
boolean_attrs = ("port_spec", "path_spec", "domain_dot",
"secure", "discard")
value_attrs = ("version",
"port", "path", "domain",
"expires",
"comment", "commenturl")
try:
while 1:
line = f.readline()
if line == "": break
if not line.startswith(header):
continue
line = line[len(header):].strip()
for data in split_header_words([line]):
name, value = data[0]
standard = {}
rest = {}
for k in boolean_attrs:
standard[k] = False
for k, v in data[1:]:
if k is not None:
lc = k.lower()
else:
lc = None
# don't lose case distinction for unknown fields
if (lc in value_attrs) or (lc in boolean_attrs):
k = lc
if k in boolean_attrs:
if v is None: v = True
standard[k] = v
elif k in value_attrs:
standard[k] = v
else:
rest[k] = v
h = standard.get
expires = h("expires")
discard = h("discard")
if expires is not None:
expires = iso2time(expires)
if expires is None:
discard = True
domain = h("domain")
domain_specified = domain.startswith(".")
c = Cookie(h("version"), name, value,
h("port"), h("port_spec"),
domain, domain_specified, h("domain_dot"),
h("path"), h("path_spec"),
h("secure"),
expires,
discard,
h("comment"),
h("commenturl"),
rest)
if not ignore_discard and c.discard:
continue
if not ignore_expires and c.is_expired(now):
continue
self.set_cookie(c)
except IOError:
raise
except Exception:
_warn_unhandled_exception()
raise LoadError("invalid Set-Cookie3 format file %r: %r" %
(filename, line))

View File

@@ -0,0 +1,149 @@
"""Mozilla / Netscape cookie loading / saving."""
import re, time
from cookielib import (_warn_unhandled_exception, FileCookieJar, LoadError,
Cookie, MISSING_FILENAME_TEXT)
class MozillaCookieJar(FileCookieJar):
"""
WARNING: you may want to backup your browser's cookies file if you use
this class to save cookies. I *think* it works, but there have been
bugs in the past!
This class differs from CookieJar only in the format it uses to save and
load cookies to and from a file. This class uses the Mozilla/Netscape
`cookies.txt' format. lynx uses this file format, too.
Don't expect cookies saved while the browser is running to be noticed by
the browser (in fact, Mozilla on unix will overwrite your saved cookies if
you change them on disk while it's running; on Windows, you probably can't
save at all while the browser is running).
Note that the Mozilla/Netscape format will downgrade RFC2965 cookies to
Netscape cookies on saving.
In particular, the cookie version and port number information is lost,
together with information about whether or not Path, Port and Discard were
specified by the Set-Cookie2 (or Set-Cookie) header, and whether or not the
domain as set in the HTTP header started with a dot (yes, I'm aware some
domains in Netscape files start with a dot and some don't -- trust me, you
really don't want to know any more about this).
Note that though Mozilla and Netscape use the same format, they use
slightly different headers. The class saves cookies using the Netscape
header by default (Mozilla can cope with that).
"""
magic_re = "#( Netscape)? HTTP Cookie File"
header = """\
# Netscape HTTP Cookie File
# http://www.netscape.com/newsref/std/cookie_spec.html
# This is a generated file! Do not edit.
"""
def _really_load(self, f, filename, ignore_discard, ignore_expires):
now = time.time()
magic = f.readline()
if not re.search(self.magic_re, magic):
f.close()
raise LoadError(
"%r does not look like a Netscape format cookies file" %
filename)
try:
while 1:
line = f.readline()
if line == "": break
# last field may be absent, so keep any trailing tab
if line.endswith("\n"): line = line[:-1]
# skip comments and blank lines XXX what is $ for?
if (line.strip().startswith(("#", "$")) or
line.strip() == ""):
continue
domain, domain_specified, path, secure, expires, name, value = \
line.split("\t")
secure = (secure == "TRUE")
domain_specified = (domain_specified == "TRUE")
if name == "":
# cookies.txt regards 'Set-Cookie: foo' as a cookie
# with no name, whereas cookielib regards it as a
# cookie with no value.
name = value
value = None
initial_dot = domain.startswith(".")
assert domain_specified == initial_dot
discard = False
if expires == "":
expires = None
discard = True
# assume path_specified is false
c = Cookie(0, name, value,
None, False,
domain, domain_specified, initial_dot,
path, False,
secure,
expires,
discard,
None,
None,
{})
if not ignore_discard and c.discard:
continue
if not ignore_expires and c.is_expired(now):
continue
self.set_cookie(c)
except IOError:
raise
except Exception:
_warn_unhandled_exception()
raise LoadError("invalid Netscape format cookies file %r: %r" %
(filename, line))
def save(self, filename=None, ignore_discard=False, ignore_expires=False):
if filename is None:
if self.filename is not None: filename = self.filename
else: raise ValueError(MISSING_FILENAME_TEXT)
f = open(filename, "w")
try:
f.write(self.header)
now = time.time()
for cookie in self:
if not ignore_discard and cookie.discard:
continue
if not ignore_expires and cookie.is_expired(now):
continue
if cookie.secure: secure = "TRUE"
else: secure = "FALSE"
if cookie.domain.startswith("."): initial_dot = "TRUE"
else: initial_dot = "FALSE"
if cookie.expires is not None:
expires = str(cookie.expires)
else:
expires = ""
if cookie.value is None:
# cookies.txt regards 'Set-Cookie: foo' as a cookie
# with no name, whereas cookielib regards it as a
# cookie with no value.
name = ""
value = cookie.name
else:
name = cookie.name
value = cookie.value
f.write(
"\t".join([cookie.domain, initial_dot, cookie.path,
secure, expires, name, value])+
"\n")
finally:
f.close()

View File

@@ -0,0 +1,128 @@
"""Record of phased-in incompatible language changes.
Each line is of the form:
FeatureName = "_Feature(" OptionalRelease "," MandatoryRelease ","
CompilerFlag ")"
where, normally, OptionalRelease < MandatoryRelease, and both are 5-tuples
of the same form as sys.version_info:
(PY_MAJOR_VERSION, # the 2 in 2.1.0a3; an int
PY_MINOR_VERSION, # the 1; an int
PY_MICRO_VERSION, # the 0; an int
PY_RELEASE_LEVEL, # "alpha", "beta", "candidate" or "final"; string
PY_RELEASE_SERIAL # the 3; an int
)
OptionalRelease records the first release in which
from __future__ import FeatureName
was accepted.
In the case of MandatoryReleases that have not yet occurred,
MandatoryRelease predicts the release in which the feature will become part
of the language.
Else MandatoryRelease records when the feature became part of the language;
in releases at or after that, modules no longer need
from __future__ import FeatureName
to use the feature in question, but may continue to use such imports.
MandatoryRelease may also be None, meaning that a planned feature got
dropped.
Instances of class _Feature have two corresponding methods,
.getOptionalRelease() and .getMandatoryRelease().
CompilerFlag is the (bitfield) flag that should be passed in the fourth
argument to the builtin function compile() to enable the feature in
dynamically compiled code. This flag is stored in the .compiler_flag
attribute on _Future instances. These values must match the appropriate
#defines of CO_xxx flags in Include/compile.h.
No feature line is ever to be deleted from this file.
"""
all_feature_names = [
"nested_scopes",
"generators",
"division",
"absolute_import",
"with_statement",
"print_function",
"unicode_literals",
]
__all__ = ["all_feature_names"] + all_feature_names
# The CO_xxx symbols are defined here under the same names used by
# compile.h, so that an editor search will find them here. However,
# they're not exported in __all__, because they don't really belong to
# this module.
CO_NESTED = 0x0010 # nested_scopes
CO_GENERATOR_ALLOWED = 0 # generators (obsolete, was 0x1000)
CO_FUTURE_DIVISION = 0x2000 # division
CO_FUTURE_ABSOLUTE_IMPORT = 0x4000 # perform absolute imports by default
CO_FUTURE_WITH_STATEMENT = 0x8000 # with statement
CO_FUTURE_PRINT_FUNCTION = 0x10000 # print function
CO_FUTURE_UNICODE_LITERALS = 0x20000 # unicode string literals
class _Feature:
def __init__(self, optionalRelease, mandatoryRelease, compiler_flag):
self.optional = optionalRelease
self.mandatory = mandatoryRelease
self.compiler_flag = compiler_flag
def getOptionalRelease(self):
"""Return first release in which this feature was recognized.
This is a 5-tuple, of the same form as sys.version_info.
"""
return self.optional
def getMandatoryRelease(self):
"""Return release in which this feature will become mandatory.
This is a 5-tuple, of the same form as sys.version_info, or, if
the feature was dropped, is None.
"""
return self.mandatory
def __repr__(self):
return "_Feature" + repr((self.optional,
self.mandatory,
self.compiler_flag))
nested_scopes = _Feature((2, 1, 0, "beta", 1),
(2, 2, 0, "alpha", 0),
CO_NESTED)
generators = _Feature((2, 2, 0, "alpha", 1),
(2, 3, 0, "final", 0),
CO_GENERATOR_ALLOWED)
division = _Feature((2, 2, 0, "alpha", 2),
(3, 0, 0, "alpha", 0),
CO_FUTURE_DIVISION)
absolute_import = _Feature((2, 5, 0, "alpha", 1),
(2, 7, 0, "alpha", 0),
CO_FUTURE_ABSOLUTE_IMPORT)
with_statement = _Feature((2, 5, 0, "alpha", 1),
(2, 6, 0, "alpha", 0),
CO_FUTURE_WITH_STATEMENT)
print_function = _Feature((2, 6, 0, "alpha", 2),
(3, 0, 0, "alpha", 0),
CO_FUTURE_PRINT_FUNCTION)
unicode_literals = _Feature((2, 6, 0, "alpha", 2),
(3, 0, 0, "alpha", 0),
CO_FUTURE_UNICODE_LITERALS)

View File

@@ -0,0 +1 @@
# This file exists as a helper for the test.test_frozen module.

View File

@@ -0,0 +1,601 @@
# Copyright 2007 Google, Inc. All Rights Reserved.
# Licensed to PSF under a Contributor Agreement.
"""Abstract Base Classes (ABCs) for collections, according to PEP 3119.
DON'T USE THIS MODULE DIRECTLY! The classes here should be imported
via collections; they are defined here only to alleviate certain
bootstrapping issues. Unit tests are in test_collections.
"""
from abc import ABCMeta, abstractmethod
import sys
__all__ = ["Hashable", "Iterable", "Iterator",
"Sized", "Container", "Callable",
"Set", "MutableSet",
"Mapping", "MutableMapping",
"MappingView", "KeysView", "ItemsView", "ValuesView",
"Sequence", "MutableSequence",
]
### ONE-TRICK PONIES ###
def _hasattr(C, attr):
try:
return any(attr in B.__dict__ for B in C.__mro__)
except AttributeError:
# Old-style class
return hasattr(C, attr)
class Hashable:
__metaclass__ = ABCMeta
@abstractmethod
def __hash__(self):
return 0
@classmethod
def __subclasshook__(cls, C):
if cls is Hashable:
try:
for B in C.__mro__:
if "__hash__" in B.__dict__:
if B.__dict__["__hash__"]:
return True
break
except AttributeError:
# Old-style class
if getattr(C, "__hash__", None):
return True
return NotImplemented
class Iterable:
__metaclass__ = ABCMeta
@abstractmethod
def __iter__(self):
while False:
yield None
@classmethod
def __subclasshook__(cls, C):
if cls is Iterable:
if _hasattr(C, "__iter__"):
return True
return NotImplemented
Iterable.register(str)
class Iterator(Iterable):
@abstractmethod
def next(self):
raise StopIteration
def __iter__(self):
return self
@classmethod
def __subclasshook__(cls, C):
if cls is Iterator:
if _hasattr(C, "next"):
return True
return NotImplemented
class Sized:
__metaclass__ = ABCMeta
@abstractmethod
def __len__(self):
return 0
@classmethod
def __subclasshook__(cls, C):
if cls is Sized:
if _hasattr(C, "__len__"):
return True
return NotImplemented
class Container:
__metaclass__ = ABCMeta
@abstractmethod
def __contains__(self, x):
return False
@classmethod
def __subclasshook__(cls, C):
if cls is Container:
if _hasattr(C, "__contains__"):
return True
return NotImplemented
class Callable:
__metaclass__ = ABCMeta
@abstractmethod
def __call__(self, *args, **kwds):
return False
@classmethod
def __subclasshook__(cls, C):
if cls is Callable:
if _hasattr(C, "__call__"):
return True
return NotImplemented
### SETS ###
class Set(Sized, Iterable, Container):
"""A set is a finite, iterable container.
This class provides concrete generic implementations of all
methods except for __contains__, __iter__ and __len__.
To override the comparisons (presumably for speed, as the
semantics are fixed), all you have to do is redefine __le__ and
then the other operations will automatically follow suit.
"""
def __le__(self, other):
if not isinstance(other, Set):
return NotImplemented
if len(self) > len(other):
return False
for elem in self:
if elem not in other:
return False
return True
def __lt__(self, other):
if not isinstance(other, Set):
return NotImplemented
return len(self) < len(other) and self.__le__(other)
def __gt__(self, other):
if not isinstance(other, Set):
return NotImplemented
return other < self
def __ge__(self, other):
if not isinstance(other, Set):
return NotImplemented
return other <= self
def __eq__(self, other):
if not isinstance(other, Set):
return NotImplemented
return len(self) == len(other) and self.__le__(other)
def __ne__(self, other):
return not (self == other)
@classmethod
def _from_iterable(cls, it):
'''Construct an instance of the class from any iterable input.
Must override this method if the class constructor signature
does not accept an iterable for an input.
'''
return cls(it)
def __and__(self, other):
if not isinstance(other, Iterable):
return NotImplemented
return self._from_iterable(value for value in other if value in self)
def isdisjoint(self, other):
for value in other:
if value in self:
return False
return True
def __or__(self, other):
if not isinstance(other, Iterable):
return NotImplemented
chain = (e for s in (self, other) for e in s)
return self._from_iterable(chain)
def __sub__(self, other):
if not isinstance(other, Set):
if not isinstance(other, Iterable):
return NotImplemented
other = self._from_iterable(other)
return self._from_iterable(value for value in self
if value not in other)
def __xor__(self, other):
if not isinstance(other, Set):
if not isinstance(other, Iterable):
return NotImplemented
other = self._from_iterable(other)
return (self - other) | (other - self)
# Sets are not hashable by default, but subclasses can change this
__hash__ = None
def _hash(self):
"""Compute the hash value of a set.
Note that we don't define __hash__: not all sets are hashable.
But if you define a hashable set type, its __hash__ should
call this function.
This must be compatible __eq__.
All sets ought to compare equal if they contain the same
elements, regardless of how they are implemented, and
regardless of the order of the elements; so there's not much
freedom for __eq__ or __hash__. We match the algorithm used
by the built-in frozenset type.
"""
MAX = sys.maxint
MASK = 2 * MAX + 1
n = len(self)
h = 1927868237 * (n + 1)
h &= MASK
for x in self:
hx = hash(x)
h ^= (hx ^ (hx << 16) ^ 89869747) * 3644798167
h &= MASK
h = h * 69069 + 907133923
h &= MASK
if h > MAX:
h -= MASK + 1
if h == -1:
h = 590923713
return h
Set.register(frozenset)
class MutableSet(Set):
@abstractmethod
def add(self, value):
"""Add an element."""
raise NotImplementedError
@abstractmethod
def discard(self, value):
"""Remove an element. Do not raise an exception if absent."""
raise NotImplementedError
def remove(self, value):
"""Remove an element. If not a member, raise a KeyError."""
if value not in self:
raise KeyError(value)
self.discard(value)
def pop(self):
"""Return the popped value. Raise KeyError if empty."""
it = iter(self)
try:
value = next(it)
except StopIteration:
raise KeyError
self.discard(value)
return value
def clear(self):
"""This is slow (creates N new iterators!) but effective."""
try:
while True:
self.pop()
except KeyError:
pass
def __ior__(self, it):
for value in it:
self.add(value)
return self
def __iand__(self, it):
for value in (self - it):
self.discard(value)
return self
def __ixor__(self, it):
if it is self:
self.clear()
else:
if not isinstance(it, Set):
it = self._from_iterable(it)
for value in it:
if value in self:
self.discard(value)
else:
self.add(value)
return self
def __isub__(self, it):
if it is self:
self.clear()
else:
for value in it:
self.discard(value)
return self
MutableSet.register(set)
### MAPPINGS ###
class Mapping(Sized, Iterable, Container):
@abstractmethod
def __getitem__(self, key):
raise KeyError
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def __contains__(self, key):
try:
self[key]
except KeyError:
return False
else:
return True
def iterkeys(self):
return iter(self)
def itervalues(self):
for key in self:
yield self[key]
def iteritems(self):
for key in self:
yield (key, self[key])
def keys(self):
return list(self)
def items(self):
return [(key, self[key]) for key in self]
def values(self):
return [self[key] for key in self]
# Mappings are not hashable by default, but subclasses can change this
__hash__ = None
def __eq__(self, other):
if not isinstance(other, Mapping):
return NotImplemented
return dict(self.items()) == dict(other.items())
def __ne__(self, other):
return not (self == other)
class MappingView(Sized):
def __init__(self, mapping):
self._mapping = mapping
def __len__(self):
return len(self._mapping)
def __repr__(self):
return '{0.__class__.__name__}({0._mapping!r})'.format(self)
class KeysView(MappingView, Set):
@classmethod
def _from_iterable(self, it):
return set(it)
def __contains__(self, key):
return key in self._mapping
def __iter__(self):
for key in self._mapping:
yield key
class ItemsView(MappingView, Set):
@classmethod
def _from_iterable(self, it):
return set(it)
def __contains__(self, item):
key, value = item
try:
v = self._mapping[key]
except KeyError:
return False
else:
return v == value
def __iter__(self):
for key in self._mapping:
yield (key, self._mapping[key])
class ValuesView(MappingView):
def __contains__(self, value):
for key in self._mapping:
if value == self._mapping[key]:
return True
return False
def __iter__(self):
for key in self._mapping:
yield self._mapping[key]
class MutableMapping(Mapping):
@abstractmethod
def __setitem__(self, key, value):
raise KeyError
@abstractmethod
def __delitem__(self, key):
raise KeyError
__marker = object()
def pop(self, key, default=__marker):
try:
value = self[key]
except KeyError:
if default is self.__marker:
raise
return default
else:
del self[key]
return value
def popitem(self):
try:
key = next(iter(self))
except StopIteration:
raise KeyError
value = self[key]
del self[key]
return key, value
def clear(self):
try:
while True:
self.popitem()
except KeyError:
pass
def update(*args, **kwds):
if len(args) > 2:
raise TypeError("update() takes at most 2 positional "
"arguments ({} given)".format(len(args)))
elif not args:
raise TypeError("update() takes at least 1 argument (0 given)")
self = args[0]
other = args[1] if len(args) >= 2 else ()
if isinstance(other, Mapping):
for key in other:
self[key] = other[key]
elif hasattr(other, "keys"):
for key in other.keys():
self[key] = other[key]
else:
for key, value in other:
self[key] = value
for key, value in kwds.items():
self[key] = value
def setdefault(self, key, default=None):
try:
return self[key]
except KeyError:
self[key] = default
return default
MutableMapping.register(dict)
### SEQUENCES ###
class Sequence(Sized, Iterable, Container):
"""All the operations on a read-only sequence.
Concrete subclasses must override __new__ or __init__,
__getitem__, and __len__.
"""
@abstractmethod
def __getitem__(self, index):
raise IndexError
def __iter__(self):
i = 0
try:
while True:
v = self[i]
yield v
i += 1
except IndexError:
return
def __contains__(self, value):
for v in self:
if v == value:
return True
return False
def __reversed__(self):
for i in reversed(range(len(self))):
yield self[i]
def index(self, value):
for i, v in enumerate(self):
if v == value:
return i
raise ValueError
def count(self, value):
return sum(1 for v in self if v == value)
Sequence.register(tuple)
Sequence.register(basestring)
Sequence.register(buffer)
Sequence.register(xrange)
class MutableSequence(Sequence):
@abstractmethod
def __setitem__(self, index, value):
raise IndexError
@abstractmethod
def __delitem__(self, index):
raise IndexError
@abstractmethod
def insert(self, index, value):
raise IndexError
def append(self, value):
self.insert(len(self), value)
def reverse(self):
n = len(self)
for i in range(n//2):
self[i], self[n-i-1] = self[n-i-1], self[i]
def extend(self, values):
for v in values:
self.append(v)
def pop(self, index=-1):
v = self[index]
del self[index]
return v
def remove(self, value):
del self[self.index(value)]
def __iadd__(self, values):
self.extend(values)
return self
MutableSequence.register(list)

1971
src/python/stdlib/_pyio.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,454 @@
"""Strptime-related classes and functions.
CLASSES:
LocaleTime -- Discovers and stores locale-specific time information
TimeRE -- Creates regexes for pattern matching a string of text containing
time information
FUNCTIONS:
_getlang -- Figure out what language is being used for the locale
strptime -- Calculates the time struct represented by the passed-in string
"""
import time
import locale
import calendar
from re import compile as re_compile
from re import IGNORECASE
from re import escape as re_escape
from datetime import date as datetime_date
try:
from thread import allocate_lock as _thread_allocate_lock
except:
from dummy_thread import allocate_lock as _thread_allocate_lock
__all__ = []
def _getlang():
# Figure out what the current language is set to.
return locale.getlocale(locale.LC_TIME)
class LocaleTime(object):
"""Stores and handles locale-specific information related to time.
ATTRIBUTES:
f_weekday -- full weekday names (7-item list)
a_weekday -- abbreviated weekday names (7-item list)
f_month -- full month names (13-item list; dummy value in [0], which
is added by code)
a_month -- abbreviated month names (13-item list, dummy value in
[0], which is added by code)
am_pm -- AM/PM representation (2-item list)
LC_date_time -- format string for date/time representation (string)
LC_date -- format string for date representation (string)
LC_time -- format string for time representation (string)
timezone -- daylight- and non-daylight-savings timezone representation
(2-item list of sets)
lang -- Language used by instance (2-item tuple)
"""
def __init__(self):
"""Set all attributes.
Order of methods called matters for dependency reasons.
The locale language is set at the offset and then checked again before
exiting. This is to make sure that the attributes were not set with a
mix of information from more than one locale. This would most likely
happen when using threads where one thread calls a locale-dependent
function while another thread changes the locale while the function in
the other thread is still running. Proper coding would call for
locks to prevent changing the locale while locale-dependent code is
running. The check here is done in case someone does not think about
doing this.
Only other possible issue is if someone changed the timezone and did
not call tz.tzset . That is an issue for the programmer, though,
since changing the timezone is worthless without that call.
"""
self.lang = _getlang()
self.__calc_weekday()
self.__calc_month()
self.__calc_am_pm()
self.__calc_timezone()
self.__calc_date_time()
if _getlang() != self.lang:
raise ValueError("locale changed during initialization")
def __pad(self, seq, front):
# Add '' to seq to either the front (is True), else the back.
seq = list(seq)
if front:
seq.insert(0, '')
else:
seq.append('')
return seq
def __calc_weekday(self):
# Set self.a_weekday and self.f_weekday using the calendar
# module.
a_weekday = [calendar.day_abbr[i].lower() for i in range(7)]
f_weekday = [calendar.day_name[i].lower() for i in range(7)]
self.a_weekday = a_weekday
self.f_weekday = f_weekday
def __calc_month(self):
# Set self.f_month and self.a_month using the calendar module.
a_month = [calendar.month_abbr[i].lower() for i in range(13)]
f_month = [calendar.month_name[i].lower() for i in range(13)]
self.a_month = a_month
self.f_month = f_month
def __calc_am_pm(self):
# Set self.am_pm by using time.strftime().
# The magic date (1999,3,17,hour,44,55,2,76,0) is not really that
# magical; just happened to have used it everywhere else where a
# static date was needed.
am_pm = []
for hour in (01,22):
time_tuple = time.struct_time((1999,3,17,hour,44,55,2,76,0))
am_pm.append(time.strftime("%p", time_tuple).lower())
self.am_pm = am_pm
def __calc_date_time(self):
# Set self.date_time, self.date, & self.time by using
# time.strftime().
# Use (1999,3,17,22,44,55,2,76,0) for magic date because the amount of
# overloaded numbers is minimized. The order in which searches for
# values within the format string is very important; it eliminates
# possible ambiguity for what something represents.
time_tuple = time.struct_time((1999,3,17,22,44,55,2,76,0))
date_time = [None, None, None]
date_time[0] = time.strftime("%c", time_tuple).lower()
date_time[1] = time.strftime("%x", time_tuple).lower()
date_time[2] = time.strftime("%X", time_tuple).lower()
replacement_pairs = [('%', '%%'), (self.f_weekday[2], '%A'),
(self.f_month[3], '%B'), (self.a_weekday[2], '%a'),
(self.a_month[3], '%b'), (self.am_pm[1], '%p'),
('1999', '%Y'), ('99', '%y'), ('22', '%H'),
('44', '%M'), ('55', '%S'), ('76', '%j'),
('17', '%d'), ('03', '%m'), ('3', '%m'),
# '3' needed for when no leading zero.
('2', '%w'), ('10', '%I')]
replacement_pairs.extend([(tz, "%Z") for tz_values in self.timezone
for tz in tz_values])
for offset,directive in ((0,'%c'), (1,'%x'), (2,'%X')):
current_format = date_time[offset]
for old, new in replacement_pairs:
# Must deal with possible lack of locale info
# manifesting itself as the empty string (e.g., Swedish's
# lack of AM/PM info) or a platform returning a tuple of empty
# strings (e.g., MacOS 9 having timezone as ('','')).
if old:
current_format = current_format.replace(old, new)
# If %W is used, then Sunday, 2005-01-03 will fall on week 0 since
# 2005-01-03 occurs before the first Monday of the year. Otherwise
# %U is used.
time_tuple = time.struct_time((1999,1,3,1,1,1,6,3,0))
if '00' in time.strftime(directive, time_tuple):
U_W = '%W'
else:
U_W = '%U'
date_time[offset] = current_format.replace('11', U_W)
self.LC_date_time = date_time[0]
self.LC_date = date_time[1]
self.LC_time = date_time[2]
def __calc_timezone(self):
# Set self.timezone by using time.tzname.
# Do not worry about possibility of time.tzname[0] == timetzname[1]
# and time.daylight; handle that in strptime .
try:
time.tzset()
except AttributeError:
pass
no_saving = frozenset(["utc", "gmt", time.tzname[0].lower()])
if time.daylight:
has_saving = frozenset([time.tzname[1].lower()])
else:
has_saving = frozenset()
self.timezone = (no_saving, has_saving)
class TimeRE(dict):
"""Handle conversion from format directives to regexes."""
def __init__(self, locale_time=None):
"""Create keys/values.
Order of execution is important for dependency reasons.
"""
if locale_time:
self.locale_time = locale_time
else:
self.locale_time = LocaleTime()
base = super(TimeRE, self)
base.__init__({
# The " \d" part of the regex is to make %c from ANSI C work
'd': r"(?P<d>3[0-1]|[1-2]\d|0[1-9]|[1-9]| [1-9])",
'f': r"(?P<f>[0-9]{1,6})",
'H': r"(?P<H>2[0-3]|[0-1]\d|\d)",
'I': r"(?P<I>1[0-2]|0[1-9]|[1-9])",
'j': r"(?P<j>36[0-6]|3[0-5]\d|[1-2]\d\d|0[1-9]\d|00[1-9]|[1-9]\d|0[1-9]|[1-9])",
'm': r"(?P<m>1[0-2]|0[1-9]|[1-9])",
'M': r"(?P<M>[0-5]\d|\d)",
'S': r"(?P<S>6[0-1]|[0-5]\d|\d)",
'U': r"(?P<U>5[0-3]|[0-4]\d|\d)",
'w': r"(?P<w>[0-6])",
# W is set below by using 'U'
'y': r"(?P<y>\d\d)",
#XXX: Does 'Y' need to worry about having less or more than
# 4 digits?
'Y': r"(?P<Y>\d\d\d\d)",
'A': self.__seqToRE(self.locale_time.f_weekday, 'A'),
'a': self.__seqToRE(self.locale_time.a_weekday, 'a'),
'B': self.__seqToRE(self.locale_time.f_month[1:], 'B'),
'b': self.__seqToRE(self.locale_time.a_month[1:], 'b'),
'p': self.__seqToRE(self.locale_time.am_pm, 'p'),
'Z': self.__seqToRE((tz for tz_names in self.locale_time.timezone
for tz in tz_names),
'Z'),
'%': '%'})
base.__setitem__('W', base.__getitem__('U').replace('U', 'W'))
base.__setitem__('c', self.pattern(self.locale_time.LC_date_time))
base.__setitem__('x', self.pattern(self.locale_time.LC_date))
base.__setitem__('X', self.pattern(self.locale_time.LC_time))
def __seqToRE(self, to_convert, directive):
"""Convert a list to a regex string for matching a directive.
Want possible matching values to be from longest to shortest. This
prevents the possibility of a match occuring for a value that also
a substring of a larger value that should have matched (e.g., 'abc'
matching when 'abcdef' should have been the match).
"""
to_convert = sorted(to_convert, key=len, reverse=True)
for value in to_convert:
if value != '':
break
else:
return ''
regex = '|'.join(re_escape(stuff) for stuff in to_convert)
regex = '(?P<%s>%s' % (directive, regex)
return '%s)' % regex
def pattern(self, format):
"""Return regex pattern for the format string.
Need to make sure that any characters that might be interpreted as
regex syntax are escaped.
"""
processed_format = ''
# The sub() call escapes all characters that might be misconstrued
# as regex syntax. Cannot use re.escape since we have to deal with
# format directives (%m, etc.).
regex_chars = re_compile(r"([\\.^$*+?\(\){}\[\]|])")
format = regex_chars.sub(r"\\\1", format)
whitespace_replacement = re_compile('\s+')
format = whitespace_replacement.sub('\s+', format)
while '%' in format:
directive_index = format.index('%')+1
processed_format = "%s%s%s" % (processed_format,
format[:directive_index-1],
self[format[directive_index]])
format = format[directive_index+1:]
return "%s%s" % (processed_format, format)
def compile(self, format):
"""Return a compiled re object for the format string."""
return re_compile(self.pattern(format), IGNORECASE)
_cache_lock = _thread_allocate_lock()
# DO NOT modify _TimeRE_cache or _regex_cache without acquiring the cache lock
# first!
_TimeRE_cache = TimeRE()
_CACHE_MAX_SIZE = 5 # Max number of regexes stored in _regex_cache
_regex_cache = {}
def _calc_julian_from_U_or_W(year, week_of_year, day_of_week, week_starts_Mon):
"""Calculate the Julian day based on the year, week of the year, and day of
the week, with week_start_day representing whether the week of the year
assumes the week starts on Sunday or Monday (6 or 0)."""
first_weekday = datetime_date(year, 1, 1).weekday()
# If we are dealing with the %U directive (week starts on Sunday), it's
# easier to just shift the view to Sunday being the first day of the
# week.
if not week_starts_Mon:
first_weekday = (first_weekday + 1) % 7
day_of_week = (day_of_week + 1) % 7
# Need to watch out for a week 0 (when the first day of the year is not
# the same as that specified by %U or %W).
week_0_length = (7 - first_weekday) % 7
if week_of_year == 0:
return 1 + day_of_week - first_weekday
else:
days_to_week = week_0_length + (7 * (week_of_year - 1))
return 1 + days_to_week + day_of_week
def _strptime(data_string, format="%a %b %d %H:%M:%S %Y"):
"""Return a time struct based on the input string and the format string."""
global _TimeRE_cache, _regex_cache
with _cache_lock:
if _getlang() != _TimeRE_cache.locale_time.lang:
_TimeRE_cache = TimeRE()
_regex_cache.clear()
if len(_regex_cache) > _CACHE_MAX_SIZE:
_regex_cache.clear()
locale_time = _TimeRE_cache.locale_time
format_regex = _regex_cache.get(format)
if not format_regex:
try:
format_regex = _TimeRE_cache.compile(format)
# KeyError raised when a bad format is found; can be specified as
# \\, in which case it was a stray % but with a space after it
except KeyError, err:
bad_directive = err.args[0]
if bad_directive == "\\":
bad_directive = "%"
del err
raise ValueError("'%s' is a bad directive in format '%s'" %
(bad_directive, format))
# IndexError only occurs when the format string is "%"
except IndexError:
raise ValueError("stray %% in format '%s'" % format)
_regex_cache[format] = format_regex
found = format_regex.match(data_string)
if not found:
raise ValueError("time data %r does not match format %r" %
(data_string, format))
if len(data_string) != found.end():
raise ValueError("unconverted data remains: %s" %
data_string[found.end():])
year = 1900
month = day = 1
hour = minute = second = fraction = 0
tz = -1
# Default to -1 to signify that values not known; not critical to have,
# though
week_of_year = -1
week_of_year_start = -1
# weekday and julian defaulted to -1 so as to signal need to calculate
# values
weekday = julian = -1
found_dict = found.groupdict()
for group_key in found_dict.iterkeys():
# Directives not explicitly handled below:
# c, x, X
# handled by making out of other directives
# U, W
# worthless without day of the week
if group_key == 'y':
year = int(found_dict['y'])
# Open Group specification for strptime() states that a %y
#value in the range of [00, 68] is in the century 2000, while
#[69,99] is in the century 1900
if year <= 68:
year += 2000
else:
year += 1900
elif group_key == 'Y':
year = int(found_dict['Y'])
elif group_key == 'm':
month = int(found_dict['m'])
elif group_key == 'B':
month = locale_time.f_month.index(found_dict['B'].lower())
elif group_key == 'b':
month = locale_time.a_month.index(found_dict['b'].lower())
elif group_key == 'd':
day = int(found_dict['d'])
elif group_key == 'H':
hour = int(found_dict['H'])
elif group_key == 'I':
hour = int(found_dict['I'])
ampm = found_dict.get('p', '').lower()
# If there was no AM/PM indicator, we'll treat this like AM
if ampm in ('', locale_time.am_pm[0]):
# We're in AM so the hour is correct unless we're
# looking at 12 midnight.
# 12 midnight == 12 AM == hour 0
if hour == 12:
hour = 0
elif ampm == locale_time.am_pm[1]:
# We're in PM so we need to add 12 to the hour unless
# we're looking at 12 noon.
# 12 noon == 12 PM == hour 12
if hour != 12:
hour += 12
elif group_key == 'M':
minute = int(found_dict['M'])
elif group_key == 'S':
second = int(found_dict['S'])
elif group_key == 'f':
s = found_dict['f']
# Pad to always return microseconds.
s += "0" * (6 - len(s))
fraction = int(s)
elif group_key == 'A':
weekday = locale_time.f_weekday.index(found_dict['A'].lower())
elif group_key == 'a':
weekday = locale_time.a_weekday.index(found_dict['a'].lower())
elif group_key == 'w':
weekday = int(found_dict['w'])
if weekday == 0:
weekday = 6
else:
weekday -= 1
elif group_key == 'j':
julian = int(found_dict['j'])
elif group_key in ('U', 'W'):
week_of_year = int(found_dict[group_key])
if group_key == 'U':
# U starts week on Sunday.
week_of_year_start = 6
else:
# W starts week on Monday.
week_of_year_start = 0
elif group_key == 'Z':
# Since -1 is default value only need to worry about setting tz if
# it can be something other than -1.
found_zone = found_dict['Z'].lower()
for value, tz_values in enumerate(locale_time.timezone):
if found_zone in tz_values:
# Deal with bad locale setup where timezone names are the
# same and yet time.daylight is true; too ambiguous to
# be able to tell what timezone has daylight savings
if (time.tzname[0] == time.tzname[1] and
time.daylight and found_zone not in ("utc", "gmt")):
break
else:
tz = value
break
# If we know the week of the year and what day of that week, we can figure
# out the Julian day of the year.
if julian == -1 and week_of_year != -1 and weekday != -1:
week_starts_Mon = True if week_of_year_start == 0 else False
julian = _calc_julian_from_U_or_W(year, week_of_year, weekday,
week_starts_Mon)
# Cannot pre-calculate datetime_date() since can change in Julian
# calculation and thus could have different value for the day of the week
# calculation.
if julian == -1:
# Need to add 1 to result since first day of the year is 1, not 0.
julian = datetime_date(year, month, day).toordinal() - \
datetime_date(year, 1, 1).toordinal() + 1
else: # Assume that if they bothered to include Julian day it will
# be accurate.
datetime_result = datetime_date.fromordinal((julian - 1) + datetime_date(year, 1, 1).toordinal())
year = datetime_result.year
month = datetime_result.month
day = datetime_result.day
if weekday == -1:
weekday = datetime_date(year, month, day).weekday()
return (time.struct_time((year, month, day,
hour, minute, second,
weekday, julian, tz)), fraction)
def _strptime_time(data_string, format="%a %b %d %H:%M:%S %Y"):
return _strptime(data_string, format)[0]

View File

@@ -0,0 +1,251 @@
"""Thread-local objects.
(Note that this module provides a Python version of the threading.local
class. Depending on the version of Python you're using, there may be a
faster one available. You should always import the `local` class from
`threading`.)
Thread-local objects support the management of thread-local data.
If you have data that you want to be local to a thread, simply create
a thread-local object and use its attributes:
>>> mydata = local()
>>> mydata.number = 42
>>> mydata.number
42
You can also access the local-object's dictionary:
>>> mydata.__dict__
{'number': 42}
>>> mydata.__dict__.setdefault('widgets', [])
[]
>>> mydata.widgets
[]
What's important about thread-local objects is that their data are
local to a thread. If we access the data in a different thread:
>>> log = []
>>> def f():
... items = mydata.__dict__.items()
... items.sort()
... log.append(items)
... mydata.number = 11
... log.append(mydata.number)
>>> import threading
>>> thread = threading.Thread(target=f)
>>> thread.start()
>>> thread.join()
>>> log
[[], 11]
we get different data. Furthermore, changes made in the other thread
don't affect data seen in this thread:
>>> mydata.number
42
Of course, values you get from a local object, including a __dict__
attribute, are for whatever thread was current at the time the
attribute was read. For that reason, you generally don't want to save
these values across threads, as they apply only to the thread they
came from.
You can create custom local objects by subclassing the local class:
>>> class MyLocal(local):
... number = 2
... initialized = False
... def __init__(self, **kw):
... if self.initialized:
... raise SystemError('__init__ called too many times')
... self.initialized = True
... self.__dict__.update(kw)
... def squared(self):
... return self.number ** 2
This can be useful to support default values, methods and
initialization. Note that if you define an __init__ method, it will be
called each time the local object is used in a separate thread. This
is necessary to initialize each thread's dictionary.
Now if we create a local object:
>>> mydata = MyLocal(color='red')
Now we have a default number:
>>> mydata.number
2
an initial color:
>>> mydata.color
'red'
>>> del mydata.color
And a method that operates on the data:
>>> mydata.squared()
4
As before, we can access the data in a separate thread:
>>> log = []
>>> thread = threading.Thread(target=f)
>>> thread.start()
>>> thread.join()
>>> log
[[('color', 'red'), ('initialized', True)], 11]
without affecting this thread's data:
>>> mydata.number
2
>>> mydata.color
Traceback (most recent call last):
...
AttributeError: 'MyLocal' object has no attribute 'color'
Note that subclasses can define slots, but they are not thread
local. They are shared across threads:
>>> class MyLocal(local):
... __slots__ = 'number'
>>> mydata = MyLocal()
>>> mydata.number = 42
>>> mydata.color = 'red'
So, the separate thread:
>>> thread = threading.Thread(target=f)
>>> thread.start()
>>> thread.join()
affects what we see:
>>> mydata.number
11
>>> del mydata
"""
__all__ = ["local"]
# We need to use objects from the threading module, but the threading
# module may also want to use our `local` class, if support for locals
# isn't compiled in to the `thread` module. This creates potential problems
# with circular imports. For that reason, we don't import `threading`
# until the bottom of this file (a hack sufficient to worm around the
# potential problems). Note that almost all platforms do have support for
# locals in the `thread` module, and there is no circular import problem
# then, so problems introduced by fiddling the order of imports here won't
# manifest on most boxes.
class _localbase(object):
__slots__ = '_local__key', '_local__args', '_local__lock'
def __new__(cls, *args, **kw):
self = object.__new__(cls)
key = '_local__key', 'thread.local.' + str(id(self))
object.__setattr__(self, '_local__key', key)
object.__setattr__(self, '_local__args', (args, kw))
object.__setattr__(self, '_local__lock', RLock())
if (args or kw) and (cls.__init__ is object.__init__):
raise TypeError("Initialization arguments are not supported")
# We need to create the thread dict in anticipation of
# __init__ being called, to make sure we don't call it
# again ourselves.
dict = object.__getattribute__(self, '__dict__')
current_thread().__dict__[key] = dict
return self
def _patch(self):
key = object.__getattribute__(self, '_local__key')
d = current_thread().__dict__.get(key)
if d is None:
d = {}
current_thread().__dict__[key] = d
object.__setattr__(self, '__dict__', d)
# we have a new instance dict, so call out __init__ if we have
# one
cls = type(self)
if cls.__init__ is not object.__init__:
args, kw = object.__getattribute__(self, '_local__args')
cls.__init__(self, *args, **kw)
else:
object.__setattr__(self, '__dict__', d)
class local(_localbase):
def __getattribute__(self, name):
lock = object.__getattribute__(self, '_local__lock')
lock.acquire()
try:
_patch(self)
return object.__getattribute__(self, name)
finally:
lock.release()
def __setattr__(self, name, value):
if name == '__dict__':
raise AttributeError(
"%r object attribute '__dict__' is read-only"
% self.__class__.__name__)
lock = object.__getattribute__(self, '_local__lock')
lock.acquire()
try:
_patch(self)
return object.__setattr__(self, name, value)
finally:
lock.release()
def __delattr__(self, name):
if name == '__dict__':
raise AttributeError(
"%r object attribute '__dict__' is read-only"
% self.__class__.__name__)
lock = object.__getattribute__(self, '_local__lock')
lock.acquire()
try:
_patch(self)
return object.__delattr__(self, name)
finally:
lock.release()
def __del__(self):
import threading
key = object.__getattribute__(self, '_local__key')
try:
# We use the non-locking API since we might already hold the lock
# (__del__ can be called at any point by the cyclic GC).
threads = threading._enumerate()
except:
# If enumerating the current threads fails, as it seems to do
# during shutdown, we'll skip cleanup under the assumption
# that there is nothing to clean up.
return
for thread in threads:
try:
__dict__ = thread.__dict__
except AttributeError:
# Thread is dying, rest in peace.
continue
if key in __dict__:
try:
del __dict__[key]
except KeyError:
pass # didn't have anything in this thread
from threading import current_thread, RLock

View File

@@ -0,0 +1,212 @@
# Access WeakSet through the weakref module.
# This code is separated-out because it is needed
# by abc.py to load everything else at startup.
from _weakref import ref
__all__ = ['WeakSet']
class _IterationGuard(object):
# This context manager registers itself in the current iterators of the
# weak container, such as to delay all removals until the context manager
# exits.
# This technique should be relatively thread-safe (since sets are).
def __init__(self, weakcontainer):
# Don't create cycles
self.weakcontainer = ref(weakcontainer)
def __enter__(self):
w = self.weakcontainer()
if w is not None:
w._iterating.add(self)
return self
def __exit__(self, e, t, b):
w = self.weakcontainer()
if w is not None:
s = w._iterating
s.remove(self)
if not s:
w._commit_removals()
class WeakSet(object):
def __init__(self, data=None):
self.data = set()
def _remove(item, selfref=ref(self)):
self = selfref()
if self is not None:
if self._iterating:
self._pending_removals.append(item)
else:
self.data.discard(item)
self._remove = _remove
# A list of keys to be removed
self._pending_removals = []
self._iterating = set()
if data is not None:
self.update(data)
def _commit_removals(self):
l = self._pending_removals
discard = self.data.discard
while l:
discard(l.pop())
def __iter__(self):
with _IterationGuard(self):
for itemref in self.data:
item = itemref()
if item is not None:
yield item
def __len__(self):
return sum(x() is not None for x in self.data)
def __contains__(self, item):
return ref(item) in self.data
def __reduce__(self):
return (self.__class__, (list(self),),
getattr(self, '__dict__', None))
__hash__ = None
def add(self, item):
if self._pending_removals:
self._commit_removals()
self.data.add(ref(item, self._remove))
def clear(self):
if self._pending_removals:
self._commit_removals()
self.data.clear()
def copy(self):
return self.__class__(self)
def pop(self):
if self._pending_removals:
self._commit_removals()
while True:
try:
itemref = self.data.pop()
except KeyError:
raise KeyError('pop from empty WeakSet')
item = itemref()
if item is not None:
return item
def remove(self, item):
if self._pending_removals:
self._commit_removals()
self.data.remove(ref(item))
def discard(self, item):
if self._pending_removals:
self._commit_removals()
self.data.discard(ref(item))
def update(self, other):
if self._pending_removals:
self._commit_removals()
if isinstance(other, self.__class__):
self.data.update(other.data)
else:
for element in other:
self.add(element)
def __ior__(self, other):
self.update(other)
return self
# Helper functions for simple delegating methods.
def _apply(self, other, method):
if not isinstance(other, self.__class__):
other = self.__class__(other)
newdata = method(other.data)
newset = self.__class__()
newset.data = newdata
return newset
def difference(self, other):
return self._apply(other, self.data.difference)
__sub__ = difference
def difference_update(self, other):
if self._pending_removals:
self._commit_removals()
if self is other:
self.data.clear()
else:
self.data.difference_update(ref(item) for item in other)
def __isub__(self, other):
if self._pending_removals:
self._commit_removals()
if self is other:
self.data.clear()
else:
self.data.difference_update(ref(item) for item in other)
return self
def intersection(self, other):
return self._apply(other, self.data.intersection)
__and__ = intersection
def intersection_update(self, other):
if self._pending_removals:
self._commit_removals()
self.data.intersection_update(ref(item) for item in other)
def __iand__(self, other):
if self._pending_removals:
self._commit_removals()
self.data.intersection_update(ref(item) for item in other)
return self
def issubset(self, other):
return self.data.issubset(ref(item) for item in other)
__lt__ = issubset
def __le__(self, other):
return self.data <= set(ref(item) for item in other)
def issuperset(self, other):
return self.data.issuperset(ref(item) for item in other)
__gt__ = issuperset
def __ge__(self, other):
return self.data >= set(ref(item) for item in other)
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return self.data == set(ref(item) for item in other)
def symmetric_difference(self, other):
return self._apply(other, self.data.symmetric_difference)
__xor__ = symmetric_difference
def symmetric_difference_update(self, other):
if self._pending_removals:
self._commit_removals()
if self is other:
self.data.clear()
else:
self.data.symmetric_difference_update(ref(item) for item in other)
def __ixor__(self, other):
if self._pending_removals:
self._commit_removals()
if self is other:
self.data.clear()
else:
self.data.symmetric_difference_update(ref(item) for item in other)
return self
def union(self, other):
return self._apply(other, self.data.union)
__or__ = union
def isdisjoint(self, other):
return len(self.intersection(other)) == 0

185
src/python/stdlib/abc.py Normal file
View File

@@ -0,0 +1,185 @@
# Copyright 2007 Google, Inc. All Rights Reserved.
# Licensed to PSF under a Contributor Agreement.
"""Abstract Base Classes (ABCs) according to PEP 3119."""
import types
from _weakrefset import WeakSet
# Instance of old-style class
class _C: pass
_InstanceType = type(_C())
def abstractmethod(funcobj):
"""A decorator indicating abstract methods.
Requires that the metaclass is ABCMeta or derived from it. A
class that has a metaclass derived from ABCMeta cannot be
instantiated unless all of its abstract methods are overridden.
The abstract methods can be called using any of the normal
'super' call mechanisms.
Usage:
class C:
__metaclass__ = ABCMeta
@abstractmethod
def my_abstract_method(self, ...):
...
"""
funcobj.__isabstractmethod__ = True
return funcobj
class abstractproperty(property):
"""A decorator indicating abstract properties.
Requires that the metaclass is ABCMeta or derived from it. A
class that has a metaclass derived from ABCMeta cannot be
instantiated unless all of its abstract properties are overridden.
The abstract properties can be called using any of the normal
'super' call mechanisms.
Usage:
class C:
__metaclass__ = ABCMeta
@abstractproperty
def my_abstract_property(self):
...
This defines a read-only property; you can also define a read-write
abstract property using the 'long' form of property declaration:
class C:
__metaclass__ = ABCMeta
def getx(self): ...
def setx(self, value): ...
x = abstractproperty(getx, setx)
"""
__isabstractmethod__ = True
class ABCMeta(type):
"""Metaclass for defining Abstract Base Classes (ABCs).
Use this metaclass to create an ABC. An ABC can be subclassed
directly, and then acts as a mix-in class. You can also register
unrelated concrete classes (even built-in classes) and unrelated
ABCs as 'virtual subclasses' -- these and their descendants will
be considered subclasses of the registering ABC by the built-in
issubclass() function, but the registering ABC won't show up in
their MRO (Method Resolution Order) nor will method
implementations defined by the registering ABC be callable (not
even via super()).
"""
# A global counter that is incremented each time a class is
# registered as a virtual subclass of anything. It forces the
# negative cache to be cleared before its next use.
_abc_invalidation_counter = 0
def __new__(mcls, name, bases, namespace):
cls = super(ABCMeta, mcls).__new__(mcls, name, bases, namespace)
# Compute set of abstract method names
abstracts = set(name
for name, value in namespace.items()
if getattr(value, "__isabstractmethod__", False))
for base in bases:
for name in getattr(base, "__abstractmethods__", set()):
value = getattr(cls, name, None)
if getattr(value, "__isabstractmethod__", False):
abstracts.add(name)
cls.__abstractmethods__ = frozenset(abstracts)
# Set up inheritance registry
cls._abc_registry = WeakSet()
cls._abc_cache = WeakSet()
cls._abc_negative_cache = WeakSet()
cls._abc_negative_cache_version = ABCMeta._abc_invalidation_counter
return cls
def register(cls, subclass):
"""Register a virtual subclass of an ABC."""
if not isinstance(subclass, (type, types.ClassType)):
raise TypeError("Can only register classes")
if issubclass(subclass, cls):
return # Already a subclass
# Subtle: test for cycles *after* testing for "already a subclass";
# this means we allow X.register(X) and interpret it as a no-op.
if issubclass(cls, subclass):
# This would create a cycle, which is bad for the algorithm below
raise RuntimeError("Refusing to create an inheritance cycle")
cls._abc_registry.add(subclass)
ABCMeta._abc_invalidation_counter += 1 # Invalidate negative cache
def _dump_registry(cls, file=None):
"""Debug helper to print the ABC registry."""
print >> file, "Class: %s.%s" % (cls.__module__, cls.__name__)
print >> file, "Inv.counter: %s" % ABCMeta._abc_invalidation_counter
for name in sorted(cls.__dict__.keys()):
if name.startswith("_abc_"):
value = getattr(cls, name)
print >> file, "%s: %r" % (name, value)
def __instancecheck__(cls, instance):
"""Override for isinstance(instance, cls)."""
# Inline the cache checking when it's simple.
subclass = getattr(instance, '__class__', None)
if subclass is not None and subclass in cls._abc_cache:
return True
subtype = type(instance)
# Old-style instances
if subtype is _InstanceType:
subtype = subclass
if subtype is subclass or subclass is None:
if (cls._abc_negative_cache_version ==
ABCMeta._abc_invalidation_counter and
subtype in cls._abc_negative_cache):
return False
# Fall back to the subclass check.
return cls.__subclasscheck__(subtype)
return (cls.__subclasscheck__(subclass) or
cls.__subclasscheck__(subtype))
def __subclasscheck__(cls, subclass):
"""Override for issubclass(subclass, cls)."""
# Check cache
if subclass in cls._abc_cache:
return True
# Check negative cache; may have to invalidate
if cls._abc_negative_cache_version < ABCMeta._abc_invalidation_counter:
# Invalidate the negative cache
cls._abc_negative_cache = WeakSet()
cls._abc_negative_cache_version = ABCMeta._abc_invalidation_counter
elif subclass in cls._abc_negative_cache:
return False
# Check the subclass hook
ok = cls.__subclasshook__(subclass)
if ok is not NotImplemented:
assert isinstance(ok, bool)
if ok:
cls._abc_cache.add(subclass)
else:
cls._abc_negative_cache.add(subclass)
return ok
# Check if it's a direct subclass
if cls in getattr(subclass, '__mro__', ()):
cls._abc_cache.add(subclass)
return True
# Check if it's a subclass of a registered class (recursive)
for rcls in cls._abc_registry:
if issubclass(subclass, rcls):
cls._abc_cache.add(subclass)
return True
# Check if it's a subclass of a subclass (recursive)
for scls in cls.__subclasses__():
if issubclass(subclass, scls):
cls._abc_cache.add(subclass)
return True
# No dice; update negative cache
cls._abc_negative_cache.add(subclass)
return False

957
src/python/stdlib/aifc.py Normal file
View File

@@ -0,0 +1,957 @@
"""Stuff to parse AIFF-C and AIFF files.
Unless explicitly stated otherwise, the description below is true
both for AIFF-C files and AIFF files.
An AIFF-C file has the following structure.
+-----------------+
| FORM |
+-----------------+
| <size> |
+----+------------+
| | AIFC |
| +------------+
| | <chunks> |
| | . |
| | . |
| | . |
+----+------------+
An AIFF file has the string "AIFF" instead of "AIFC".
A chunk consists of an identifier (4 bytes) followed by a size (4 bytes,
big endian order), followed by the data. The size field does not include
the size of the 8 byte header.
The following chunk types are recognized.
FVER
<version number of AIFF-C defining document> (AIFF-C only).
MARK
<# of markers> (2 bytes)
list of markers:
<marker ID> (2 bytes, must be > 0)
<position> (4 bytes)
<marker name> ("pstring")
COMM
<# of channels> (2 bytes)
<# of sound frames> (4 bytes)
<size of the samples> (2 bytes)
<sampling frequency> (10 bytes, IEEE 80-bit extended
floating point)
in AIFF-C files only:
<compression type> (4 bytes)
<human-readable version of compression type> ("pstring")
SSND
<offset> (4 bytes, not used by this program)
<blocksize> (4 bytes, not used by this program)
<sound data>
A pstring consists of 1 byte length, a string of characters, and 0 or 1
byte pad to make the total length even.
Usage.
Reading AIFF files:
f = aifc.open(file, 'r')
where file is either the name of a file or an open file pointer.
The open file pointer must have methods read(), seek(), and close().
In some types of audio files, if the setpos() method is not used,
the seek() method is not necessary.
This returns an instance of a class with the following public methods:
getnchannels() -- returns number of audio channels (1 for
mono, 2 for stereo)
getsampwidth() -- returns sample width in bytes
getframerate() -- returns sampling frequency
getnframes() -- returns number of audio frames
getcomptype() -- returns compression type ('NONE' for AIFF files)
getcompname() -- returns human-readable version of
compression type ('not compressed' for AIFF files)
getparams() -- returns a tuple consisting of all of the
above in the above order
getmarkers() -- get the list of marks in the audio file or None
if there are no marks
getmark(id) -- get mark with the specified id (raises an error
if the mark does not exist)
readframes(n) -- returns at most n frames of audio
rewind() -- rewind to the beginning of the audio stream
setpos(pos) -- seek to the specified position
tell() -- return the current position
close() -- close the instance (make it unusable)
The position returned by tell(), the position given to setpos() and
the position of marks are all compatible and have nothing to do with
the actual position in the file.
The close() method is called automatically when the class instance
is destroyed.
Writing AIFF files:
f = aifc.open(file, 'w')
where file is either the name of a file or an open file pointer.
The open file pointer must have methods write(), tell(), seek(), and
close().
This returns an instance of a class with the following public methods:
aiff() -- create an AIFF file (AIFF-C default)
aifc() -- create an AIFF-C file
setnchannels(n) -- set the number of channels
setsampwidth(n) -- set the sample width
setframerate(n) -- set the frame rate
setnframes(n) -- set the number of frames
setcomptype(type, name)
-- set the compression type and the
human-readable compression type
setparams(tuple)
-- set all parameters at once
setmark(id, pos, name)
-- add specified mark to the list of marks
tell() -- return current position in output file (useful
in combination with setmark())
writeframesraw(data)
-- write audio frames without pathing up the
file header
writeframes(data)
-- write audio frames and patch up the file header
close() -- patch up the file header and close the
output file
You should set the parameters before the first writeframesraw or
writeframes. The total number of frames does not need to be set,
but when it is set to the correct value, the header does not have to
be patched up.
It is best to first set all parameters, perhaps possibly the
compression type, and then write audio frames using writeframesraw.
When all frames have been written, either call writeframes('') or
close() to patch up the sizes in the header.
Marks can be added anytime. If there are any marks, ypu must call
close() after all frames have been written.
The close() method is called automatically when the class instance
is destroyed.
When a file is opened with the extension '.aiff', an AIFF file is
written, otherwise an AIFF-C file is written. This default can be
changed by calling aiff() or aifc() before the first writeframes or
writeframesraw.
"""
import struct
import __builtin__
__all__ = ["Error","open","openfp"]
class Error(Exception):
pass
_AIFC_version = 0xA2805140L # Version 1 of AIFF-C
def _read_long(file):
try:
return struct.unpack('>l', file.read(4))[0]
except struct.error:
raise EOFError
def _read_ulong(file):
try:
return struct.unpack('>L', file.read(4))[0]
except struct.error:
raise EOFError
def _read_short(file):
try:
return struct.unpack('>h', file.read(2))[0]
except struct.error:
raise EOFError
def _read_string(file):
length = ord(file.read(1))
if length == 0:
data = ''
else:
data = file.read(length)
if length & 1 == 0:
dummy = file.read(1)
return data
_HUGE_VAL = 1.79769313486231e+308 # See <limits.h>
def _read_float(f): # 10 bytes
expon = _read_short(f) # 2 bytes
sign = 1
if expon < 0:
sign = -1
expon = expon + 0x8000
himant = _read_ulong(f) # 4 bytes
lomant = _read_ulong(f) # 4 bytes
if expon == himant == lomant == 0:
f = 0.0
elif expon == 0x7FFF:
f = _HUGE_VAL
else:
expon = expon - 16383
f = (himant * 0x100000000L + lomant) * pow(2.0, expon - 63)
return sign * f
def _write_short(f, x):
f.write(struct.pack('>h', x))
def _write_long(f, x):
f.write(struct.pack('>L', x))
def _write_string(f, s):
if len(s) > 255:
raise ValueError("string exceeds maximum pstring length")
f.write(chr(len(s)))
f.write(s)
if len(s) & 1 == 0:
f.write(chr(0))
def _write_float(f, x):
import math
if x < 0:
sign = 0x8000
x = x * -1
else:
sign = 0
if x == 0:
expon = 0
himant = 0
lomant = 0
else:
fmant, expon = math.frexp(x)
if expon > 16384 or fmant >= 1: # Infinity or NaN
expon = sign|0x7FFF
himant = 0
lomant = 0
else: # Finite
expon = expon + 16382
if expon < 0: # denormalized
fmant = math.ldexp(fmant, expon)
expon = 0
expon = expon | sign
fmant = math.ldexp(fmant, 32)
fsmant = math.floor(fmant)
himant = long(fsmant)
fmant = math.ldexp(fmant - fsmant, 32)
fsmant = math.floor(fmant)
lomant = long(fsmant)
_write_short(f, expon)
_write_long(f, himant)
_write_long(f, lomant)
from chunk import Chunk
class Aifc_read:
# Variables used in this class:
#
# These variables are available to the user though appropriate
# methods of this class:
# _file -- the open file with methods read(), close(), and seek()
# set through the __init__() method
# _nchannels -- the number of audio channels
# available through the getnchannels() method
# _nframes -- the number of audio frames
# available through the getnframes() method
# _sampwidth -- the number of bytes per audio sample
# available through the getsampwidth() method
# _framerate -- the sampling frequency
# available through the getframerate() method
# _comptype -- the AIFF-C compression type ('NONE' if AIFF)
# available through the getcomptype() method
# _compname -- the human-readable AIFF-C compression type
# available through the getcomptype() method
# _markers -- the marks in the audio file
# available through the getmarkers() and getmark()
# methods
# _soundpos -- the position in the audio stream
# available through the tell() method, set through the
# setpos() method
#
# These variables are used internally only:
# _version -- the AIFF-C version number
# _decomp -- the decompressor from builtin module cl
# _comm_chunk_read -- 1 iff the COMM chunk has been read
# _aifc -- 1 iff reading an AIFF-C file
# _ssnd_seek_needed -- 1 iff positioned correctly in audio
# file for readframes()
# _ssnd_chunk -- instantiation of a chunk class for the SSND chunk
# _framesize -- size of one frame in the file
def initfp(self, file):
self._version = 0
self._decomp = None
self._convert = None
self._markers = []
self._soundpos = 0
self._file = file
chunk = Chunk(file)
if chunk.getname() != 'FORM':
raise Error, 'file does not start with FORM id'
formdata = chunk.read(4)
if formdata == 'AIFF':
self._aifc = 0
elif formdata == 'AIFC':
self._aifc = 1
else:
raise Error, 'not an AIFF or AIFF-C file'
self._comm_chunk_read = 0
while 1:
self._ssnd_seek_needed = 1
try:
chunk = Chunk(self._file)
except EOFError:
break
chunkname = chunk.getname()
if chunkname == 'COMM':
self._read_comm_chunk(chunk)
self._comm_chunk_read = 1
elif chunkname == 'SSND':
self._ssnd_chunk = chunk
dummy = chunk.read(8)
self._ssnd_seek_needed = 0
elif chunkname == 'FVER':
self._version = _read_ulong(chunk)
elif chunkname == 'MARK':
self._readmark(chunk)
chunk.skip()
if not self._comm_chunk_read or not self._ssnd_chunk:
raise Error, 'COMM chunk and/or SSND chunk missing'
if self._aifc and self._decomp:
import cl
params = [cl.ORIGINAL_FORMAT, 0,
cl.BITS_PER_COMPONENT, self._sampwidth * 8,
cl.FRAME_RATE, self._framerate]
if self._nchannels == 1:
params[1] = cl.MONO
elif self._nchannels == 2:
params[1] = cl.STEREO_INTERLEAVED
else:
raise Error, 'cannot compress more than 2 channels'
self._decomp.SetParams(params)
def __init__(self, f):
if type(f) == type(''):
f = __builtin__.open(f, 'rb')
# else, assume it is an open file object already
self.initfp(f)
#
# User visible methods.
#
def getfp(self):
return self._file
def rewind(self):
self._ssnd_seek_needed = 1
self._soundpos = 0
def close(self):
if self._decomp:
self._decomp.CloseDecompressor()
self._decomp = None
self._file.close()
def tell(self):
return self._soundpos
def getnchannels(self):
return self._nchannels
def getnframes(self):
return self._nframes
def getsampwidth(self):
return self._sampwidth
def getframerate(self):
return self._framerate
def getcomptype(self):
return self._comptype
def getcompname(self):
return self._compname
## def getversion(self):
## return self._version
def getparams(self):
return self.getnchannels(), self.getsampwidth(), \
self.getframerate(), self.getnframes(), \
self.getcomptype(), self.getcompname()
def getmarkers(self):
if len(self._markers) == 0:
return None
return self._markers
def getmark(self, id):
for marker in self._markers:
if id == marker[0]:
return marker
raise Error, 'marker %r does not exist' % (id,)
def setpos(self, pos):
if pos < 0 or pos > self._nframes:
raise Error, 'position not in range'
self._soundpos = pos
self._ssnd_seek_needed = 1
def readframes(self, nframes):
if self._ssnd_seek_needed:
self._ssnd_chunk.seek(0)
dummy = self._ssnd_chunk.read(8)
pos = self._soundpos * self._framesize
if pos:
self._ssnd_chunk.seek(pos + 8)
self._ssnd_seek_needed = 0
if nframes == 0:
return ''
data = self._ssnd_chunk.read(nframes * self._framesize)
if self._convert and data:
data = self._convert(data)
self._soundpos = self._soundpos + len(data) // (self._nchannels * self._sampwidth)
return data
#
# Internal methods.
#
def _decomp_data(self, data):
import cl
dummy = self._decomp.SetParam(cl.FRAME_BUFFER_SIZE,
len(data) * 2)
return self._decomp.Decompress(len(data) // self._nchannels,
data)
def _ulaw2lin(self, data):
import audioop
return audioop.ulaw2lin(data, 2)
def _adpcm2lin(self, data):
import audioop
if not hasattr(self, '_adpcmstate'):
# first time
self._adpcmstate = None
data, self._adpcmstate = audioop.adpcm2lin(data, 2,
self._adpcmstate)
return data
def _read_comm_chunk(self, chunk):
self._nchannels = _read_short(chunk)
self._nframes = _read_long(chunk)
self._sampwidth = (_read_short(chunk) + 7) // 8
self._framerate = int(_read_float(chunk))
self._framesize = self._nchannels * self._sampwidth
if self._aifc:
#DEBUG: SGI's soundeditor produces a bad size :-(
kludge = 0
if chunk.chunksize == 18:
kludge = 1
print 'Warning: bad COMM chunk size'
chunk.chunksize = 23
#DEBUG end
self._comptype = chunk.read(4)
#DEBUG start
if kludge:
length = ord(chunk.file.read(1))
if length & 1 == 0:
length = length + 1
chunk.chunksize = chunk.chunksize + length
chunk.file.seek(-1, 1)
#DEBUG end
self._compname = _read_string(chunk)
if self._comptype != 'NONE':
if self._comptype == 'G722':
try:
import audioop
except ImportError:
pass
else:
self._convert = self._adpcm2lin
self._framesize = self._framesize // 4
return
# for ULAW and ALAW try Compression Library
try:
import cl
except ImportError:
if self._comptype == 'ULAW':
try:
import audioop
self._convert = self._ulaw2lin
self._framesize = self._framesize // 2
return
except ImportError:
pass
raise Error, 'cannot read compressed AIFF-C files'
if self._comptype == 'ULAW':
scheme = cl.G711_ULAW
self._framesize = self._framesize // 2
elif self._comptype == 'ALAW':
scheme = cl.G711_ALAW
self._framesize = self._framesize // 2
else:
raise Error, 'unsupported compression type'
self._decomp = cl.OpenDecompressor(scheme)
self._convert = self._decomp_data
else:
self._comptype = 'NONE'
self._compname = 'not compressed'
def _readmark(self, chunk):
nmarkers = _read_short(chunk)
# Some files appear to contain invalid counts.
# Cope with this by testing for EOF.
try:
for i in range(nmarkers):
id = _read_short(chunk)
pos = _read_long(chunk)
name = _read_string(chunk)
if pos or name:
# some files appear to have
# dummy markers consisting of
# a position 0 and name ''
self._markers.append((id, pos, name))
except EOFError:
print 'Warning: MARK chunk contains only',
print len(self._markers),
if len(self._markers) == 1: print 'marker',
else: print 'markers',
print 'instead of', nmarkers
class Aifc_write:
# Variables used in this class:
#
# These variables are user settable through appropriate methods
# of this class:
# _file -- the open file with methods write(), close(), tell(), seek()
# set through the __init__() method
# _comptype -- the AIFF-C compression type ('NONE' in AIFF)
# set through the setcomptype() or setparams() method
# _compname -- the human-readable AIFF-C compression type
# set through the setcomptype() or setparams() method
# _nchannels -- the number of audio channels
# set through the setnchannels() or setparams() method
# _sampwidth -- the number of bytes per audio sample
# set through the setsampwidth() or setparams() method
# _framerate -- the sampling frequency
# set through the setframerate() or setparams() method
# _nframes -- the number of audio frames written to the header
# set through the setnframes() or setparams() method
# _aifc -- whether we're writing an AIFF-C file or an AIFF file
# set through the aifc() method, reset through the
# aiff() method
#
# These variables are used internally only:
# _version -- the AIFF-C version number
# _comp -- the compressor from builtin module cl
# _nframeswritten -- the number of audio frames actually written
# _datalength -- the size of the audio samples written to the header
# _datawritten -- the size of the audio samples actually written
def __init__(self, f):
if type(f) == type(''):
filename = f
f = __builtin__.open(f, 'wb')
else:
# else, assume it is an open file object already
filename = '???'
self.initfp(f)
if filename[-5:] == '.aiff':
self._aifc = 0
else:
self._aifc = 1
def initfp(self, file):
self._file = file
self._version = _AIFC_version
self._comptype = 'NONE'
self._compname = 'not compressed'
self._comp = None
self._convert = None
self._nchannels = 0
self._sampwidth = 0
self._framerate = 0
self._nframes = 0
self._nframeswritten = 0
self._datawritten = 0
self._datalength = 0
self._markers = []
self._marklength = 0
self._aifc = 1 # AIFF-C is default
def __del__(self):
if self._file:
self.close()
#
# User visible methods.
#
def aiff(self):
if self._nframeswritten:
raise Error, 'cannot change parameters after starting to write'
self._aifc = 0
def aifc(self):
if self._nframeswritten:
raise Error, 'cannot change parameters after starting to write'
self._aifc = 1
def setnchannels(self, nchannels):
if self._nframeswritten:
raise Error, 'cannot change parameters after starting to write'
if nchannels < 1:
raise Error, 'bad # of channels'
self._nchannels = nchannels
def getnchannels(self):
if not self._nchannels:
raise Error, 'number of channels not set'
return self._nchannels
def setsampwidth(self, sampwidth):
if self._nframeswritten:
raise Error, 'cannot change parameters after starting to write'
if sampwidth < 1 or sampwidth > 4:
raise Error, 'bad sample width'
self._sampwidth = sampwidth
def getsampwidth(self):
if not self._sampwidth:
raise Error, 'sample width not set'
return self._sampwidth
def setframerate(self, framerate):
if self._nframeswritten:
raise Error, 'cannot change parameters after starting to write'
if framerate <= 0:
raise Error, 'bad frame rate'
self._framerate = framerate
def getframerate(self):
if not self._framerate:
raise Error, 'frame rate not set'
return self._framerate
def setnframes(self, nframes):
if self._nframeswritten:
raise Error, 'cannot change parameters after starting to write'
self._nframes = nframes
def getnframes(self):
return self._nframeswritten
def setcomptype(self, comptype, compname):
if self._nframeswritten:
raise Error, 'cannot change parameters after starting to write'
if comptype not in ('NONE', 'ULAW', 'ALAW', 'G722'):
raise Error, 'unsupported compression type'
self._comptype = comptype
self._compname = compname
def getcomptype(self):
return self._comptype
def getcompname(self):
return self._compname
## def setversion(self, version):
## if self._nframeswritten:
## raise Error, 'cannot change parameters after starting to write'
## self._version = version
def setparams(self, info):
nchannels, sampwidth, framerate, nframes, comptype, compname = info
if self._nframeswritten:
raise Error, 'cannot change parameters after starting to write'
if comptype not in ('NONE', 'ULAW', 'ALAW', 'G722'):
raise Error, 'unsupported compression type'
self.setnchannels(nchannels)
self.setsampwidth(sampwidth)
self.setframerate(framerate)
self.setnframes(nframes)
self.setcomptype(comptype, compname)
def getparams(self):
if not self._nchannels or not self._sampwidth or not self._framerate:
raise Error, 'not all parameters set'
return self._nchannels, self._sampwidth, self._framerate, \
self._nframes, self._comptype, self._compname
def setmark(self, id, pos, name):
if id <= 0:
raise Error, 'marker ID must be > 0'
if pos < 0:
raise Error, 'marker position must be >= 0'
if type(name) != type(''):
raise Error, 'marker name must be a string'
for i in range(len(self._markers)):
if id == self._markers[i][0]:
self._markers[i] = id, pos, name
return
self._markers.append((id, pos, name))
def getmark(self, id):
for marker in self._markers:
if id == marker[0]:
return marker
raise Error, 'marker %r does not exist' % (id,)
def getmarkers(self):
if len(self._markers) == 0:
return None
return self._markers
def tell(self):
return self._nframeswritten
def writeframesraw(self, data):
self._ensure_header_written(len(data))
nframes = len(data) // (self._sampwidth * self._nchannels)
if self._convert:
data = self._convert(data)
self._file.write(data)
self._nframeswritten = self._nframeswritten + nframes
self._datawritten = self._datawritten + len(data)
def writeframes(self, data):
self.writeframesraw(data)
if self._nframeswritten != self._nframes or \
self._datalength != self._datawritten:
self._patchheader()
def close(self):
self._ensure_header_written(0)
if self._datawritten & 1:
# quick pad to even size
self._file.write(chr(0))
self._datawritten = self._datawritten + 1
self._writemarkers()
if self._nframeswritten != self._nframes or \
self._datalength != self._datawritten or \
self._marklength:
self._patchheader()
if self._comp:
self._comp.CloseCompressor()
self._comp = None
# Prevent ref cycles
self._convert = None
self._file.close()
#
# Internal methods.
#
def _comp_data(self, data):
import cl
dummy = self._comp.SetParam(cl.FRAME_BUFFER_SIZE, len(data))
dummy = self._comp.SetParam(cl.COMPRESSED_BUFFER_SIZE, len(data))
return self._comp.Compress(self._nframes, data)
def _lin2ulaw(self, data):
import audioop
return audioop.lin2ulaw(data, 2)
def _lin2adpcm(self, data):
import audioop
if not hasattr(self, '_adpcmstate'):
self._adpcmstate = None
data, self._adpcmstate = audioop.lin2adpcm(data, 2,
self._adpcmstate)
return data
def _ensure_header_written(self, datasize):
if not self._nframeswritten:
if self._comptype in ('ULAW', 'ALAW'):
if not self._sampwidth:
self._sampwidth = 2
if self._sampwidth != 2:
raise Error, 'sample width must be 2 when compressing with ULAW or ALAW'
if self._comptype == 'G722':
if not self._sampwidth:
self._sampwidth = 2
if self._sampwidth != 2:
raise Error, 'sample width must be 2 when compressing with G7.22 (ADPCM)'
if not self._nchannels:
raise Error, '# channels not specified'
if not self._sampwidth:
raise Error, 'sample width not specified'
if not self._framerate:
raise Error, 'sampling rate not specified'
self._write_header(datasize)
def _init_compression(self):
if self._comptype == 'G722':
self._convert = self._lin2adpcm
return
try:
import cl
except ImportError:
if self._comptype == 'ULAW':
try:
import audioop
self._convert = self._lin2ulaw
return
except ImportError:
pass
raise Error, 'cannot write compressed AIFF-C files'
if self._comptype == 'ULAW':
scheme = cl.G711_ULAW
elif self._comptype == 'ALAW':
scheme = cl.G711_ALAW
else:
raise Error, 'unsupported compression type'
self._comp = cl.OpenCompressor(scheme)
params = [cl.ORIGINAL_FORMAT, 0,
cl.BITS_PER_COMPONENT, self._sampwidth * 8,
cl.FRAME_RATE, self._framerate,
cl.FRAME_BUFFER_SIZE, 100,
cl.COMPRESSED_BUFFER_SIZE, 100]
if self._nchannels == 1:
params[1] = cl.MONO
elif self._nchannels == 2:
params[1] = cl.STEREO_INTERLEAVED
else:
raise Error, 'cannot compress more than 2 channels'
self._comp.SetParams(params)
# the compressor produces a header which we ignore
dummy = self._comp.Compress(0, '')
self._convert = self._comp_data
def _write_header(self, initlength):
if self._aifc and self._comptype != 'NONE':
self._init_compression()
self._file.write('FORM')
if not self._nframes:
self._nframes = initlength // (self._nchannels * self._sampwidth)
self._datalength = self._nframes * self._nchannels * self._sampwidth
if self._datalength & 1:
self._datalength = self._datalength + 1
if self._aifc:
if self._comptype in ('ULAW', 'ALAW'):
self._datalength = self._datalength // 2
if self._datalength & 1:
self._datalength = self._datalength + 1
elif self._comptype == 'G722':
self._datalength = (self._datalength + 3) // 4
if self._datalength & 1:
self._datalength = self._datalength + 1
self._form_length_pos = self._file.tell()
commlength = self._write_form_length(self._datalength)
if self._aifc:
self._file.write('AIFC')
self._file.write('FVER')
_write_long(self._file, 4)
_write_long(self._file, self._version)
else:
self._file.write('AIFF')
self._file.write('COMM')
_write_long(self._file, commlength)
_write_short(self._file, self._nchannels)
self._nframes_pos = self._file.tell()
_write_long(self._file, self._nframes)
_write_short(self._file, self._sampwidth * 8)
_write_float(self._file, self._framerate)
if self._aifc:
self._file.write(self._comptype)
_write_string(self._file, self._compname)
self._file.write('SSND')
self._ssnd_length_pos = self._file.tell()
_write_long(self._file, self._datalength + 8)
_write_long(self._file, 0)
_write_long(self._file, 0)
def _write_form_length(self, datalength):
if self._aifc:
commlength = 18 + 5 + len(self._compname)
if commlength & 1:
commlength = commlength + 1
verslength = 12
else:
commlength = 18
verslength = 0
_write_long(self._file, 4 + verslength + self._marklength + \
8 + commlength + 16 + datalength)
return commlength
def _patchheader(self):
curpos = self._file.tell()
if self._datawritten & 1:
datalength = self._datawritten + 1
self._file.write(chr(0))
else:
datalength = self._datawritten
if datalength == self._datalength and \
self._nframes == self._nframeswritten and \
self._marklength == 0:
self._file.seek(curpos, 0)
return
self._file.seek(self._form_length_pos, 0)
dummy = self._write_form_length(datalength)
self._file.seek(self._nframes_pos, 0)
_write_long(self._file, self._nframeswritten)
self._file.seek(self._ssnd_length_pos, 0)
_write_long(self._file, datalength + 8)
self._file.seek(curpos, 0)
self._nframes = self._nframeswritten
self._datalength = datalength
def _writemarkers(self):
if len(self._markers) == 0:
return
self._file.write('MARK')
length = 2
for marker in self._markers:
id, pos, name = marker
length = length + len(name) + 1 + 6
if len(name) & 1 == 0:
length = length + 1
_write_long(self._file, length)
self._marklength = length + 8
_write_short(self._file, len(self._markers))
for marker in self._markers:
id, pos, name = marker
_write_short(self._file, id)
_write_long(self._file, pos)
_write_string(self._file, name)
def open(f, mode=None):
if mode is None:
if hasattr(f, 'mode'):
mode = f.mode
else:
mode = 'rb'
if mode in ('r', 'rb'):
return Aifc_read(f)
elif mode in ('w', 'wb'):
return Aifc_write(f)
else:
raise Error, "mode must be 'r', 'rb', 'w', or 'wb'"
openfp = open # B/W compatibility
if __name__ == '__main__':
import sys
if not sys.argv[1:]:
sys.argv.append('/usr/demos/data/audio/bach.aiff')
fn = sys.argv[1]
f = open(fn, 'r')
print "Reading", fn
print "nchannels =", f.getnchannels()
print "nframes =", f.getnframes()
print "sampwidth =", f.getsampwidth()
print "framerate =", f.getframerate()
print "comptype =", f.getcomptype()
print "compname =", f.getcompname()
if sys.argv[2:]:
gn = sys.argv[2]
print "Writing", gn
g = open(gn, 'w')
g.setparams(f.getparams())
while 1:
data = f.readframes(1024)
if not data:
break
g.writeframes(data)
g.close()
f.close()
print "Done."

View File

@@ -0,0 +1,4 @@
import webbrowser
webbrowser.open("http://xkcd.com/353/")

View File

@@ -0,0 +1,83 @@
"""Generic interface to all dbm clones.
Instead of
import dbm
d = dbm.open(file, 'w', 0666)
use
import anydbm
d = anydbm.open(file, 'w')
The returned object is a dbhash, gdbm, dbm or dumbdbm object,
dependent on the type of database being opened (determined by whichdb
module) in the case of an existing dbm. If the dbm does not exist and
the create or new flag ('c' or 'n') was specified, the dbm type will
be determined by the availability of the modules (tested in the above
order).
It has the following interface (key and data are strings):
d[key] = data # store data at key (may override data at
# existing key)
data = d[key] # retrieve data at key (raise KeyError if no
# such key)
del d[key] # delete data stored at key (raises KeyError
# if no such key)
flag = key in d # true if the key exists
list = d.keys() # return a list of all existing keys (slow!)
Future versions may change the order in which implementations are
tested for existence, add interfaces to other dbm-like
implementations.
The open function has an optional second argument. This can be 'r',
for read-only access, 'w', for read-write access of an existing
database, 'c' for read-write access to a new or existing database, and
'n' for read-write access to a new database. The default is 'r'.
Note: 'r' and 'w' fail if the database doesn't exist; 'c' creates it
only if it doesn't exist; and 'n' always creates a new database.
"""
class error(Exception):
pass
_names = ['dbhash', 'gdbm', 'dbm', 'dumbdbm']
_errors = [error]
_defaultmod = None
for _name in _names:
try:
_mod = __import__(_name)
except ImportError:
continue
if not _defaultmod:
_defaultmod = _mod
_errors.append(_mod.error)
if not _defaultmod:
raise ImportError, "no dbm clone found; tried %s" % _names
error = tuple(_errors)
def open(file, flag = 'r', mode = 0666):
# guess the type of an existing database
from whichdb import whichdb
result=whichdb(file)
if result is None:
# db doesn't exist
if 'c' in flag or 'n' in flag:
# file doesn't exist and the new
# flag was used so use default type
mod = _defaultmod
else:
raise error, "need 'c' or 'n' flag to open new db"
elif result == "":
# db type cannot be determined
raise error, "db type could not be determined"
else:
mod = __import__(result)
return mod.open(file, flag, mode)

File diff suppressed because it is too large Load Diff

313
src/python/stdlib/ast.py Normal file
View File

@@ -0,0 +1,313 @@
# -*- coding: utf-8 -*-
"""
ast
~~~
The `ast` module helps Python applications to process trees of the Python
abstract syntax grammar. The abstract syntax itself might change with
each Python release; this module helps to find out programmatically what
the current grammar looks like and allows modifications of it.
An abstract syntax tree can be generated by passing `ast.PyCF_ONLY_AST` as
a flag to the `compile()` builtin function or by using the `parse()`
function from this module. The result will be a tree of objects whose
classes all inherit from `ast.AST`.
A modified abstract syntax tree can be compiled into a Python code object
using the built-in `compile()` function.
Additionally various helper functions are provided that make working with
the trees simpler. The main intention of the helper functions and this
module in general is to provide an easy to use interface for libraries
that work tightly with the python syntax (template engines for example).
:copyright: Copyright 2008 by Armin Ronacher.
:license: Python License.
"""
from _ast import *
from _ast import __version__
def parse(expr, filename='<unknown>', mode='exec'):
"""
Parse an expression into an AST node.
Equivalent to compile(expr, filename, mode, PyCF_ONLY_AST).
"""
return compile(expr, filename, mode, PyCF_ONLY_AST)
def literal_eval(node_or_string):
"""
Safely evaluate an expression node or a string containing a Python
expression. The string or node provided may only consist of the following
Python literal structures: strings, numbers, tuples, lists, dicts, booleans,
and None.
"""
_safe_names = {'None': None, 'True': True, 'False': False}
if isinstance(node_or_string, basestring):
node_or_string = parse(node_or_string, mode='eval')
if isinstance(node_or_string, Expression):
node_or_string = node_or_string.body
def _convert(node):
if isinstance(node, Str):
return node.s
elif isinstance(node, Num):
return node.n
elif isinstance(node, Tuple):
return tuple(map(_convert, node.elts))
elif isinstance(node, List):
return list(map(_convert, node.elts))
elif isinstance(node, Dict):
return dict((_convert(k), _convert(v)) for k, v
in zip(node.keys, node.values))
elif isinstance(node, Name):
if node.id in _safe_names:
return _safe_names[node.id]
elif isinstance(node, BinOp) and \
isinstance(node.op, (Add, Sub)) and \
isinstance(node.right, Num) and \
isinstance(node.right.n, complex) and \
isinstance(node.left, Num) and \
isinstance(node.left.n, (int, long, float)):
left = node.left.n
right = node.right.n
if isinstance(node.op, Add):
return left + right
else:
return left - right
raise ValueError('malformed string')
return _convert(node_or_string)
def dump(node, annotate_fields=True, include_attributes=False):
"""
Return a formatted dump of the tree in *node*. This is mainly useful for
debugging purposes. The returned string will show the names and the values
for fields. This makes the code impossible to evaluate, so if evaluation is
wanted *annotate_fields* must be set to False. Attributes such as line
numbers and column offsets are not dumped by default. If this is wanted,
*include_attributes* can be set to True.
"""
def _format(node):
if isinstance(node, AST):
fields = [(a, _format(b)) for a, b in iter_fields(node)]
rv = '%s(%s' % (node.__class__.__name__, ', '.join(
('%s=%s' % field for field in fields)
if annotate_fields else
(b for a, b in fields)
))
if include_attributes and node._attributes:
rv += fields and ', ' or ' '
rv += ', '.join('%s=%s' % (a, _format(getattr(node, a)))
for a in node._attributes)
return rv + ')'
elif isinstance(node, list):
return '[%s]' % ', '.join(_format(x) for x in node)
return repr(node)
if not isinstance(node, AST):
raise TypeError('expected AST, got %r' % node.__class__.__name__)
return _format(node)
def copy_location(new_node, old_node):
"""
Copy source location (`lineno` and `col_offset` attributes) from
*old_node* to *new_node* if possible, and return *new_node*.
"""
for attr in 'lineno', 'col_offset':
if attr in old_node._attributes and attr in new_node._attributes \
and hasattr(old_node, attr):
setattr(new_node, attr, getattr(old_node, attr))
return new_node
def fix_missing_locations(node):
"""
When you compile a node tree with compile(), the compiler expects lineno and
col_offset attributes for every node that supports them. This is rather
tedious to fill in for generated nodes, so this helper adds these attributes
recursively where not already set, by setting them to the values of the
parent node. It works recursively starting at *node*.
"""
def _fix(node, lineno, col_offset):
if 'lineno' in node._attributes:
if not hasattr(node, 'lineno'):
node.lineno = lineno
else:
lineno = node.lineno
if 'col_offset' in node._attributes:
if not hasattr(node, 'col_offset'):
node.col_offset = col_offset
else:
col_offset = node.col_offset
for child in iter_child_nodes(node):
_fix(child, lineno, col_offset)
_fix(node, 1, 0)
return node
def increment_lineno(node, n=1):
"""
Increment the line number of each node in the tree starting at *node* by *n*.
This is useful to "move code" to a different location in a file.
"""
if 'lineno' in node._attributes:
node.lineno = getattr(node, 'lineno', 0) + n
for child in walk(node):
if 'lineno' in child._attributes:
child.lineno = getattr(child, 'lineno', 0) + n
return node
def iter_fields(node):
"""
Yield a tuple of ``(fieldname, value)`` for each field in ``node._fields``
that is present on *node*.
"""
for field in node._fields:
try:
yield field, getattr(node, field)
except AttributeError:
pass
def iter_child_nodes(node):
"""
Yield all direct child nodes of *node*, that is, all fields that are nodes
and all items of fields that are lists of nodes.
"""
for name, field in iter_fields(node):
if isinstance(field, AST):
yield field
elif isinstance(field, list):
for item in field:
if isinstance(item, AST):
yield item
def get_docstring(node, clean=True):
"""
Return the docstring for the given node or None if no docstring can
be found. If the node provided does not have docstrings a TypeError
will be raised.
"""
if not isinstance(node, (FunctionDef, ClassDef, Module)):
raise TypeError("%r can't have docstrings" % node.__class__.__name__)
if node.body and isinstance(node.body[0], Expr) and \
isinstance(node.body[0].value, Str):
if clean:
import inspect
return inspect.cleandoc(node.body[0].value.s)
return node.body[0].value.s
def walk(node):
"""
Recursively yield all child nodes of *node*, in no specified order. This is
useful if you only want to modify nodes in place and don't care about the
context.
"""
from collections import deque
todo = deque([node])
while todo:
node = todo.popleft()
todo.extend(iter_child_nodes(node))
yield node
class NodeVisitor(object):
"""
A node visitor base class that walks the abstract syntax tree and calls a
visitor function for every node found. This function may return a value
which is forwarded by the `visit` method.
This class is meant to be subclassed, with the subclass adding visitor
methods.
Per default the visitor functions for the nodes are ``'visit_'`` +
class name of the node. So a `TryFinally` node visit function would
be `visit_TryFinally`. This behavior can be changed by overriding
the `visit` method. If no visitor function exists for a node
(return value `None`) the `generic_visit` visitor is used instead.
Don't use the `NodeVisitor` if you want to apply changes to nodes during
traversing. For this a special visitor exists (`NodeTransformer`) that
allows modifications.
"""
def visit(self, node):
"""Visit a node."""
method = 'visit_' + node.__class__.__name__
visitor = getattr(self, method, self.generic_visit)
return visitor(node)
def generic_visit(self, node):
"""Called if no explicit visitor function exists for a node."""
for field, value in iter_fields(node):
if isinstance(value, list):
for item in value:
if isinstance(item, AST):
self.visit(item)
elif isinstance(value, AST):
self.visit(value)
class NodeTransformer(NodeVisitor):
"""
A :class:`NodeVisitor` subclass that walks the abstract syntax tree and
allows modification of nodes.
The `NodeTransformer` will walk the AST and use the return value of the
visitor methods to replace or remove the old node. If the return value of
the visitor method is ``None``, the node will be removed from its location,
otherwise it is replaced with the return value. The return value may be the
original node in which case no replacement takes place.
Here is an example transformer that rewrites all occurrences of name lookups
(``foo``) to ``data['foo']``::
class RewriteName(NodeTransformer):
def visit_Name(self, node):
return copy_location(Subscript(
value=Name(id='data', ctx=Load()),
slice=Index(value=Str(s=node.id)),
ctx=node.ctx
), node)
Keep in mind that if the node you're operating on has child nodes you must
either transform the child nodes yourself or call the :meth:`generic_visit`
method for the node first.
For nodes that were part of a collection of statements (that applies to all
statement nodes), the visitor may also return a list of nodes rather than
just a single node.
Usually you use the transformer like this::
node = YourTransformer().visit(node)
"""
def generic_visit(self, node):
for field, old_value in iter_fields(node):
old_value = getattr(node, field, None)
if isinstance(old_value, list):
new_values = []
for value in old_value:
if isinstance(value, AST):
value = self.visit(value)
if value is None:
continue
elif not isinstance(value, AST):
new_values.extend(value)
continue
new_values.append(value)
old_value[:] = new_values
elif isinstance(old_value, AST):
new_node = self.visit(old_value)
if new_node is None:
delattr(node, field)
else:
setattr(node, field, new_node)
return node

View File

@@ -0,0 +1,314 @@
# -*- Mode: Python; tab-width: 4 -*-
# Id: asynchat.py,v 2.26 2000/09/07 22:29:26 rushing Exp
# Author: Sam Rushing <rushing@nightmare.com>
# ======================================================================
# Copyright 1996 by Sam Rushing
#
# All Rights Reserved
#
# Permission to use, copy, modify, and distribute this software and
# its documentation for any purpose and without fee is hereby
# granted, provided that the above copyright notice appear in all
# copies and that both that copyright notice and this permission
# notice appear in supporting documentation, and that the name of Sam
# Rushing not be used in advertising or publicity pertaining to
# distribution of the software without specific, written prior
# permission.
#
# SAM RUSHING DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE,
# INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN
# NO EVENT SHALL SAM RUSHING BE LIABLE FOR ANY SPECIAL, INDIRECT OR
# CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
# OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
# ======================================================================
r"""A class supporting chat-style (command/response) protocols.
This class adds support for 'chat' style protocols - where one side
sends a 'command', and the other sends a response (examples would be
the common internet protocols - smtp, nntp, ftp, etc..).
The handle_read() method looks at the input stream for the current
'terminator' (usually '\r\n' for single-line responses, '\r\n.\r\n'
for multi-line output), calling self.found_terminator() on its
receipt.
for example:
Say you build an async nntp client using this class. At the start
of the connection, you'll have self.terminator set to '\r\n', in
order to process the single-line greeting. Just before issuing a
'LIST' command you'll set it to '\r\n.\r\n'. The output of the LIST
command will be accumulated (using your own 'collect_incoming_data'
method) up to the terminator, and then control will be returned to
you - by calling your self.found_terminator() method.
"""
import socket
import asyncore
from collections import deque
from sys import py3kwarning
from warnings import filterwarnings, catch_warnings
class async_chat (asyncore.dispatcher):
"""This is an abstract class. You must derive from this class, and add
the two methods collect_incoming_data() and found_terminator()"""
# these are overridable defaults
ac_in_buffer_size = 4096
ac_out_buffer_size = 4096
def __init__ (self, sock=None, map=None):
# for string terminator matching
self.ac_in_buffer = ''
# we use a list here rather than cStringIO for a few reasons...
# del lst[:] is faster than sio.truncate(0)
# lst = [] is faster than sio.truncate(0)
# cStringIO will be gaining unicode support in py3k, which
# will negatively affect the performance of bytes compared to
# a ''.join() equivalent
self.incoming = []
# we toss the use of the "simple producer" and replace it with
# a pure deque, which the original fifo was a wrapping of
self.producer_fifo = deque()
asyncore.dispatcher.__init__ (self, sock, map)
def collect_incoming_data(self, data):
raise NotImplementedError("must be implemented in subclass")
def _collect_incoming_data(self, data):
self.incoming.append(data)
def _get_data(self):
d = ''.join(self.incoming)
del self.incoming[:]
return d
def found_terminator(self):
raise NotImplementedError("must be implemented in subclass")
def set_terminator (self, term):
"Set the input delimiter. Can be a fixed string of any length, an integer, or None"
self.terminator = term
def get_terminator (self):
return self.terminator
# grab some more data from the socket,
# throw it to the collector method,
# check for the terminator,
# if found, transition to the next state.
def handle_read (self):
try:
data = self.recv (self.ac_in_buffer_size)
except socket.error, why:
self.handle_error()
return
self.ac_in_buffer = self.ac_in_buffer + data
# Continue to search for self.terminator in self.ac_in_buffer,
# while calling self.collect_incoming_data. The while loop
# is necessary because we might read several data+terminator
# combos with a single recv(4096).
while self.ac_in_buffer:
lb = len(self.ac_in_buffer)
terminator = self.get_terminator()
if not terminator:
# no terminator, collect it all
self.collect_incoming_data (self.ac_in_buffer)
self.ac_in_buffer = ''
elif isinstance(terminator, int) or isinstance(terminator, long):
# numeric terminator
n = terminator
if lb < n:
self.collect_incoming_data (self.ac_in_buffer)
self.ac_in_buffer = ''
self.terminator = self.terminator - lb
else:
self.collect_incoming_data (self.ac_in_buffer[:n])
self.ac_in_buffer = self.ac_in_buffer[n:]
self.terminator = 0
self.found_terminator()
else:
# 3 cases:
# 1) end of buffer matches terminator exactly:
# collect data, transition
# 2) end of buffer matches some prefix:
# collect data to the prefix
# 3) end of buffer does not match any prefix:
# collect data
terminator_len = len(terminator)
index = self.ac_in_buffer.find(terminator)
if index != -1:
# we found the terminator
if index > 0:
# don't bother reporting the empty string (source of subtle bugs)
self.collect_incoming_data (self.ac_in_buffer[:index])
self.ac_in_buffer = self.ac_in_buffer[index+terminator_len:]
# This does the Right Thing if the terminator is changed here.
self.found_terminator()
else:
# check for a prefix of the terminator
index = find_prefix_at_end (self.ac_in_buffer, terminator)
if index:
if index != lb:
# we found a prefix, collect up to the prefix
self.collect_incoming_data (self.ac_in_buffer[:-index])
self.ac_in_buffer = self.ac_in_buffer[-index:]
break
else:
# no prefix, collect it all
self.collect_incoming_data (self.ac_in_buffer)
self.ac_in_buffer = ''
def handle_write (self):
self.initiate_send()
def handle_close (self):
self.close()
def push (self, data):
sabs = self.ac_out_buffer_size
if len(data) > sabs:
for i in xrange(0, len(data), sabs):
self.producer_fifo.append(data[i:i+sabs])
else:
self.producer_fifo.append(data)
self.initiate_send()
def push_with_producer (self, producer):
self.producer_fifo.append(producer)
self.initiate_send()
def readable (self):
"predicate for inclusion in the readable for select()"
# cannot use the old predicate, it violates the claim of the
# set_terminator method.
# return (len(self.ac_in_buffer) <= self.ac_in_buffer_size)
return 1
def writable (self):
"predicate for inclusion in the writable for select()"
return self.producer_fifo or (not self.connected)
def close_when_done (self):
"automatically close this channel once the outgoing queue is empty"
self.producer_fifo.append(None)
def initiate_send(self):
while self.producer_fifo and self.connected:
first = self.producer_fifo[0]
# handle empty string/buffer or None entry
if not first:
del self.producer_fifo[0]
if first is None:
self.handle_close()
return
# handle classic producer behavior
obs = self.ac_out_buffer_size
try:
with catch_warnings():
if py3kwarning:
filterwarnings("ignore", ".*buffer", DeprecationWarning)
data = buffer(first, 0, obs)
except TypeError:
data = first.more()
if data:
self.producer_fifo.appendleft(data)
else:
del self.producer_fifo[0]
continue
# send the data
try:
num_sent = self.send(data)
except socket.error:
self.handle_error()
return
if num_sent:
if num_sent < len(data) or obs < len(first):
self.producer_fifo[0] = first[num_sent:]
else:
del self.producer_fifo[0]
# we tried to send some actual data
return
def discard_buffers (self):
# Emergencies only!
self.ac_in_buffer = ''
del self.incoming[:]
self.producer_fifo.clear()
class simple_producer:
def __init__ (self, data, buffer_size=512):
self.data = data
self.buffer_size = buffer_size
def more (self):
if len (self.data) > self.buffer_size:
result = self.data[:self.buffer_size]
self.data = self.data[self.buffer_size:]
return result
else:
result = self.data
self.data = ''
return result
class fifo:
def __init__ (self, list=None):
if not list:
self.list = deque()
else:
self.list = deque(list)
def __len__ (self):
return len(self.list)
def is_empty (self):
return not self.list
def first (self):
return self.list[0]
def push (self, data):
self.list.append(data)
def pop (self):
if self.list:
return (1, self.list.popleft())
else:
return (0, None)
# Given 'haystack', see if any prefix of 'needle' is at its end. This
# assumes an exact match has already been checked. Return the number of
# characters matched.
# for example:
# f_p_a_e ("qwerty\r", "\r\n") => 1
# f_p_a_e ("qwertydkjf", "\r\n") => 0
# f_p_a_e ("qwerty\r\n", "\r\n") => <undefined>
# this could maybe be made faster with a computed regex?
# [answer: no; circa Python-2.0, Jan 2001]
# new python: 28961/s
# old python: 18307/s
# re: 12820/s
# regex: 14035/s
def find_prefix_at_end (haystack, needle):
l = len(needle) - 1
while l and not haystack.endswith(needle[:l]):
l -= 1
return l

View File

@@ -0,0 +1,651 @@
# -*- Mode: Python -*-
# Id: asyncore.py,v 2.51 2000/09/07 22:29:26 rushing Exp
# Author: Sam Rushing <rushing@nightmare.com>
# ======================================================================
# Copyright 1996 by Sam Rushing
#
# All Rights Reserved
#
# Permission to use, copy, modify, and distribute this software and
# its documentation for any purpose and without fee is hereby
# granted, provided that the above copyright notice appear in all
# copies and that both that copyright notice and this permission
# notice appear in supporting documentation, and that the name of Sam
# Rushing not be used in advertising or publicity pertaining to
# distribution of the software without specific, written prior
# permission.
#
# SAM RUSHING DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE,
# INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN
# NO EVENT SHALL SAM RUSHING BE LIABLE FOR ANY SPECIAL, INDIRECT OR
# CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
# OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
# ======================================================================
"""Basic infrastructure for asynchronous socket service clients and servers.
There are only two ways to have a program on a single processor do "more
than one thing at a time". Multi-threaded programming is the simplest and
most popular way to do it, but there is another very different technique,
that lets you have nearly all the advantages of multi-threading, without
actually using multiple threads. it's really only practical if your program
is largely I/O bound. If your program is CPU bound, then pre-emptive
scheduled threads are probably what you really need. Network servers are
rarely CPU-bound, however.
If your operating system supports the select() system call in its I/O
library (and nearly all do), then you can use it to juggle multiple
communication channels at once; doing other work while your I/O is taking
place in the "background." Although this strategy can seem strange and
complex, especially at first, it is in many ways easier to understand and
control than multi-threaded programming. The module documented here solves
many of the difficult problems for you, making the task of building
sophisticated high-performance network servers and clients a snap.
"""
import select
import socket
import sys
import time
import warnings
import os
from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, ECONNRESET, EINVAL, \
ENOTCONN, ESHUTDOWN, EINTR, EISCONN, EBADF, ECONNABORTED, errorcode
try:
socket_map
except NameError:
socket_map = {}
def _strerror(err):
try:
return os.strerror(err)
except (ValueError, OverflowError, NameError):
if err in errorcode:
return errorcode[err]
return "Unknown error %s" %err
class ExitNow(Exception):
pass
_reraised_exceptions = (ExitNow, KeyboardInterrupt, SystemExit)
def read(obj):
try:
obj.handle_read_event()
except _reraised_exceptions:
raise
except:
obj.handle_error()
def write(obj):
try:
obj.handle_write_event()
except _reraised_exceptions:
raise
except:
obj.handle_error()
def _exception(obj):
try:
obj.handle_expt_event()
except _reraised_exceptions:
raise
except:
obj.handle_error()
def readwrite(obj, flags):
try:
if flags & select.POLLIN:
obj.handle_read_event()
if flags & select.POLLOUT:
obj.handle_write_event()
if flags & select.POLLPRI:
obj.handle_expt_event()
if flags & (select.POLLHUP | select.POLLERR | select.POLLNVAL):
obj.handle_close()
except socket.error, e:
if e.args[0] not in (EBADF, ECONNRESET, ENOTCONN, ESHUTDOWN, ECONNABORTED):
obj.handle_error()
else:
obj.handle_close()
except _reraised_exceptions:
raise
except:
obj.handle_error()
def poll(timeout=0.0, map=None):
if map is None:
map = socket_map
if map:
r = []; w = []; e = []
for fd, obj in map.items():
is_r = obj.readable()
is_w = obj.writable()
if is_r:
r.append(fd)
if is_w:
w.append(fd)
if is_r or is_w:
e.append(fd)
if [] == r == w == e:
time.sleep(timeout)
return
try:
r, w, e = select.select(r, w, e, timeout)
except select.error, err:
if err.args[0] != EINTR:
raise
else:
return
for fd in r:
obj = map.get(fd)
if obj is None:
continue
read(obj)
for fd in w:
obj = map.get(fd)
if obj is None:
continue
write(obj)
for fd in e:
obj = map.get(fd)
if obj is None:
continue
_exception(obj)
def poll2(timeout=0.0, map=None):
# Use the poll() support added to the select module in Python 2.0
if map is None:
map = socket_map
if timeout is not None:
# timeout is in milliseconds
timeout = int(timeout*1000)
pollster = select.poll()
if map:
for fd, obj in map.items():
flags = 0
if obj.readable():
flags |= select.POLLIN | select.POLLPRI
if obj.writable():
flags |= select.POLLOUT
if flags:
# Only check for exceptions if object was either readable
# or writable.
flags |= select.POLLERR | select.POLLHUP | select.POLLNVAL
pollster.register(fd, flags)
try:
r = pollster.poll(timeout)
except select.error, err:
if err.args[0] != EINTR:
raise
r = []
for fd, flags in r:
obj = map.get(fd)
if obj is None:
continue
readwrite(obj, flags)
poll3 = poll2 # Alias for backward compatibility
def loop(timeout=30.0, use_poll=False, map=None, count=None):
if map is None:
map = socket_map
if use_poll and hasattr(select, 'poll'):
poll_fun = poll2
else:
poll_fun = poll
if count is None:
while map:
poll_fun(timeout, map)
else:
while map and count > 0:
poll_fun(timeout, map)
count = count - 1
class dispatcher:
debug = False
connected = False
accepting = False
closing = False
addr = None
ignore_log_types = frozenset(['warning'])
def __init__(self, sock=None, map=None):
if map is None:
self._map = socket_map
else:
self._map = map
self._fileno = None
if sock:
# Set to nonblocking just to make sure for cases where we
# get a socket from a blocking source.
sock.setblocking(0)
self.set_socket(sock, map)
self.connected = True
# The constructor no longer requires that the socket
# passed be connected.
try:
self.addr = sock.getpeername()
except socket.error, err:
if err.args[0] == ENOTCONN:
# To handle the case where we got an unconnected
# socket.
self.connected = False
else:
# The socket is broken in some unknown way, alert
# the user and remove it from the map (to prevent
# polling of broken sockets).
self.del_channel(map)
raise
else:
self.socket = None
def __repr__(self):
status = [self.__class__.__module__+"."+self.__class__.__name__]
if self.accepting and self.addr:
status.append('listening')
elif self.connected:
status.append('connected')
if self.addr is not None:
try:
status.append('%s:%d' % self.addr)
except TypeError:
status.append(repr(self.addr))
return '<%s at %#x>' % (' '.join(status), id(self))
__str__ = __repr__
def add_channel(self, map=None):
#self.log_info('adding channel %s' % self)
if map is None:
map = self._map
map[self._fileno] = self
def del_channel(self, map=None):
fd = self._fileno
if map is None:
map = self._map
if fd in map:
#self.log_info('closing channel %d:%s' % (fd, self))
del map[fd]
self._fileno = None
def create_socket(self, family, type):
self.family_and_type = family, type
sock = socket.socket(family, type)
sock.setblocking(0)
self.set_socket(sock)
def set_socket(self, sock, map=None):
self.socket = sock
## self.__dict__['socket'] = sock
self._fileno = sock.fileno()
self.add_channel(map)
def set_reuse_addr(self):
# try to re-use a server port if possible
try:
self.socket.setsockopt(
socket.SOL_SOCKET, socket.SO_REUSEADDR,
self.socket.getsockopt(socket.SOL_SOCKET,
socket.SO_REUSEADDR) | 1
)
except socket.error:
pass
# ==================================================
# predicates for select()
# these are used as filters for the lists of sockets
# to pass to select().
# ==================================================
def readable(self):
return True
def writable(self):
return True
# ==================================================
# socket object methods.
# ==================================================
def listen(self, num):
self.accepting = True
if os.name == 'nt' and num > 5:
num = 5
return self.socket.listen(num)
def bind(self, addr):
self.addr = addr
return self.socket.bind(addr)
def connect(self, address):
self.connected = False
err = self.socket.connect_ex(address)
if err in (EINPROGRESS, EALREADY, EWOULDBLOCK) \
or err == EINVAL and os.name in ('nt', 'ce'):
return
if err in (0, EISCONN):
self.addr = address
self.handle_connect_event()
else:
raise socket.error(err, errorcode[err])
def accept(self):
# XXX can return either an address pair or None
try:
conn, addr = self.socket.accept()
except TypeError:
return None
except socket.error as why:
if why.args[0] in (EWOULDBLOCK, ECONNABORTED):
return None
else:
raise
else:
return conn, addr
def send(self, data):
try:
result = self.socket.send(data)
return result
except socket.error, why:
if why.args[0] == EWOULDBLOCK:
return 0
elif why.args[0] in (ECONNRESET, ENOTCONN, ESHUTDOWN, ECONNABORTED):
self.handle_close()
return 0
else:
raise
def recv(self, buffer_size):
try:
data = self.socket.recv(buffer_size)
if not data:
# a closed connection is indicated by signaling
# a read condition, and having recv() return 0.
self.handle_close()
return ''
else:
return data
except socket.error, why:
# winsock sometimes throws ENOTCONN
if why.args[0] in [ECONNRESET, ENOTCONN, ESHUTDOWN, ECONNABORTED]:
self.handle_close()
return ''
else:
raise
def close(self):
self.connected = False
self.accepting = False
self.del_channel()
try:
self.socket.close()
except socket.error, why:
if why.args[0] not in (ENOTCONN, EBADF):
raise
# cheap inheritance, used to pass all other attribute
# references to the underlying socket object.
def __getattr__(self, attr):
try:
retattr = getattr(self.socket, attr)
except AttributeError:
raise AttributeError("%s instance has no attribute '%s'"
%(self.__class__.__name__, attr))
else:
msg = "%(me)s.%(attr)s is deprecated. Use %(me)s.socket.%(attr)s " \
"instead." % {'me': self.__class__.__name__, 'attr':attr}
warnings.warn(msg, DeprecationWarning, stacklevel=2)
return retattr
# log and log_info may be overridden to provide more sophisticated
# logging and warning methods. In general, log is for 'hit' logging
# and 'log_info' is for informational, warning and error logging.
def log(self, message):
sys.stderr.write('log: %s\n' % str(message))
def log_info(self, message, type='info'):
if type not in self.ignore_log_types:
print '%s: %s' % (type, message)
def handle_read_event(self):
if self.accepting:
# accepting sockets are never connected, they "spawn" new
# sockets that are connected
self.handle_accept()
elif not self.connected:
self.handle_connect_event()
self.handle_read()
else:
self.handle_read()
def handle_connect_event(self):
err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
if err != 0:
raise socket.error(err, _strerror(err))
self.handle_connect()
self.connected = True
def handle_write_event(self):
if self.accepting:
# Accepting sockets shouldn't get a write event.
# We will pretend it didn't happen.
return
if not self.connected:
#check for errors
err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
if err != 0:
raise socket.error(err, _strerror(err))
self.handle_connect_event()
self.handle_write()
def handle_expt_event(self):
# handle_expt_event() is called if there might be an error on the
# socket, or if there is OOB data
# check for the error condition first
err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
if err != 0:
# we can get here when select.select() says that there is an
# exceptional condition on the socket
# since there is an error, we'll go ahead and close the socket
# like we would in a subclassed handle_read() that received no
# data
self.handle_close()
else:
self.handle_expt()
def handle_error(self):
nil, t, v, tbinfo = compact_traceback()
# sometimes a user repr method will crash.
try:
self_repr = repr(self)
except:
self_repr = '<__repr__(self) failed for object at %0x>' % id(self)
self.log_info(
'uncaptured python exception, closing channel %s (%s:%s %s)' % (
self_repr,
t,
v,
tbinfo
),
'error'
)
self.handle_close()
def handle_expt(self):
self.log_info('unhandled incoming priority event', 'warning')
def handle_read(self):
self.log_info('unhandled read event', 'warning')
def handle_write(self):
self.log_info('unhandled write event', 'warning')
def handle_connect(self):
self.log_info('unhandled connect event', 'warning')
def handle_accept(self):
self.log_info('unhandled accept event', 'warning')
def handle_close(self):
self.log_info('unhandled close event', 'warning')
self.close()
# ---------------------------------------------------------------------------
# adds simple buffered output capability, useful for simple clients.
# [for more sophisticated usage use asynchat.async_chat]
# ---------------------------------------------------------------------------
class dispatcher_with_send(dispatcher):
def __init__(self, sock=None, map=None):
dispatcher.__init__(self, sock, map)
self.out_buffer = ''
def initiate_send(self):
num_sent = 0
num_sent = dispatcher.send(self, self.out_buffer[:512])
self.out_buffer = self.out_buffer[num_sent:]
def handle_write(self):
self.initiate_send()
def writable(self):
return (not self.connected) or len(self.out_buffer)
def send(self, data):
if self.debug:
self.log_info('sending %s' % repr(data))
self.out_buffer = self.out_buffer + data
self.initiate_send()
# ---------------------------------------------------------------------------
# used for debugging.
# ---------------------------------------------------------------------------
def compact_traceback():
t, v, tb = sys.exc_info()
tbinfo = []
if not tb: # Must have a traceback
raise AssertionError("traceback does not exist")
while tb:
tbinfo.append((
tb.tb_frame.f_code.co_filename,
tb.tb_frame.f_code.co_name,
str(tb.tb_lineno)
))
tb = tb.tb_next
# just to be safe
del tb
file, function, line = tbinfo[-1]
info = ' '.join(['[%s|%s|%s]' % x for x in tbinfo])
return (file, function, line), t, v, info
def close_all(map=None, ignore_all=False):
if map is None:
map = socket_map
for x in map.values():
try:
x.close()
except OSError, x:
if x.args[0] == EBADF:
pass
elif not ignore_all:
raise
except _reraised_exceptions:
raise
except:
if not ignore_all:
raise
map.clear()
# Asynchronous File I/O:
#
# After a little research (reading man pages on various unixen, and
# digging through the linux kernel), I've determined that select()
# isn't meant for doing asynchronous file i/o.
# Heartening, though - reading linux/mm/filemap.c shows that linux
# supports asynchronous read-ahead. So _MOST_ of the time, the data
# will be sitting in memory for us already when we go to read it.
#
# What other OS's (besides NT) support async file i/o? [VMS?]
#
# Regardless, this is useful for pipes, and stdin/stdout...
if os.name == 'posix':
import fcntl
class file_wrapper:
# Here we override just enough to make a file
# look like a socket for the purposes of asyncore.
# The passed fd is automatically os.dup()'d
def __init__(self, fd):
self.fd = os.dup(fd)
def recv(self, *args):
return os.read(self.fd, *args)
def send(self, *args):
return os.write(self.fd, *args)
def getsockopt(self, level, optname, buflen=None):
if (level == socket.SOL_SOCKET and
optname == socket.SO_ERROR and
not buflen):
return 0
raise NotImplementedError("Only asyncore specific behaviour "
"implemented.")
read = recv
write = send
def close(self):
os.close(self.fd)
def fileno(self):
return self.fd
class file_dispatcher(dispatcher):
def __init__(self, fd, map=None):
dispatcher.__init__(self, None, map)
self.connected = True
try:
fd = fd.fileno()
except AttributeError:
pass
self.set_file(fd)
# set it to non-blocking mode
flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0)
flags = flags | os.O_NONBLOCK
fcntl.fcntl(fd, fcntl.F_SETFL, flags)
def set_file(self, fd):
self.socket = file_wrapper(fd)
self._fileno = self.socket.fileno()
self.add_channel()

View File

@@ -0,0 +1,65 @@
"""
atexit.py - allow programmer to define multiple exit functions to be executed
upon normal program termination.
One public function, register, is defined.
"""
__all__ = ["register"]
import sys
_exithandlers = []
def _run_exitfuncs():
"""run any registered exit functions
_exithandlers is traversed in reverse order so functions are executed
last in, first out.
"""
exc_info = None
while _exithandlers:
func, targs, kargs = _exithandlers.pop()
try:
func(*targs, **kargs)
except SystemExit:
exc_info = sys.exc_info()
except:
import traceback
print >> sys.stderr, "Error in atexit._run_exitfuncs:"
traceback.print_exc()
exc_info = sys.exc_info()
if exc_info is not None:
raise exc_info[0], exc_info[1], exc_info[2]
def register(func, *targs, **kargs):
"""register a function to be executed upon normal program termination
func - function to be called at exit
targs - optional arguments to pass to func
kargs - optional keyword arguments to pass to func
func is returned to facilitate usage as a decorator.
"""
_exithandlers.append((func, targs, kargs))
return func
if hasattr(sys, "exitfunc"):
# Assume it's another registered exit function - append it to our list
register(sys.exitfunc)
sys.exitfunc = _run_exitfuncs
if __name__ == "__main__":
def x1():
print "running x1"
def x2(n):
print "running x2(%r)" % (n,)
def x3(n, kwd=None):
print "running x3(%r, kwd=%r)" % (n, kwd)
register(x1)
register(x2, 12)
register(x3, 5, "bar")
register(x3, "no kwd args")

View File

@@ -0,0 +1,260 @@
"""Classes for manipulating audio devices (currently only for Sun and SGI)"""
from warnings import warnpy3k
warnpy3k("the audiodev module has been removed in Python 3.0", stacklevel=2)
del warnpy3k
__all__ = ["error","AudioDev"]
class error(Exception):
pass
class Play_Audio_sgi:
# Private instance variables
## if 0: access frameratelist, nchannelslist, sampwidthlist, oldparams, \
## params, config, inited_outrate, inited_width, \
## inited_nchannels, port, converter, classinited: private
classinited = 0
frameratelist = nchannelslist = sampwidthlist = None
def initclass(self):
import AL
self.frameratelist = [
(48000, AL.RATE_48000),
(44100, AL.RATE_44100),
(32000, AL.RATE_32000),
(22050, AL.RATE_22050),
(16000, AL.RATE_16000),
(11025, AL.RATE_11025),
( 8000, AL.RATE_8000),
]
self.nchannelslist = [
(1, AL.MONO),
(2, AL.STEREO),
(4, AL.QUADRO),
]
self.sampwidthlist = [
(1, AL.SAMPLE_8),
(2, AL.SAMPLE_16),
(3, AL.SAMPLE_24),
]
self.classinited = 1
def __init__(self):
import al, AL
if not self.classinited:
self.initclass()
self.oldparams = []
self.params = [AL.OUTPUT_RATE, 0]
self.config = al.newconfig()
self.inited_outrate = 0
self.inited_width = 0
self.inited_nchannels = 0
self.converter = None
self.port = None
return
def __del__(self):
if self.port:
self.stop()
if self.oldparams:
import al, AL
al.setparams(AL.DEFAULT_DEVICE, self.oldparams)
self.oldparams = []
def wait(self):
if not self.port:
return
import time
while self.port.getfilled() > 0:
time.sleep(0.1)
self.stop()
def stop(self):
if self.port:
self.port.closeport()
self.port = None
if self.oldparams:
import al, AL
al.setparams(AL.DEFAULT_DEVICE, self.oldparams)
self.oldparams = []
def setoutrate(self, rate):
for (raw, cooked) in self.frameratelist:
if rate == raw:
self.params[1] = cooked
self.inited_outrate = 1
break
else:
raise error, 'bad output rate'
def setsampwidth(self, width):
for (raw, cooked) in self.sampwidthlist:
if width == raw:
self.config.setwidth(cooked)
self.inited_width = 1
break
else:
if width == 0:
import AL
self.inited_width = 0
self.config.setwidth(AL.SAMPLE_16)
self.converter = self.ulaw2lin
else:
raise error, 'bad sample width'
def setnchannels(self, nchannels):
for (raw, cooked) in self.nchannelslist:
if nchannels == raw:
self.config.setchannels(cooked)
self.inited_nchannels = 1
break
else:
raise error, 'bad # of channels'
def writeframes(self, data):
if not (self.inited_outrate and self.inited_nchannels):
raise error, 'params not specified'
if not self.port:
import al, AL
self.port = al.openport('Python', 'w', self.config)
self.oldparams = self.params[:]
al.getparams(AL.DEFAULT_DEVICE, self.oldparams)
al.setparams(AL.DEFAULT_DEVICE, self.params)
if self.converter:
data = self.converter(data)
self.port.writesamps(data)
def getfilled(self):
if self.port:
return self.port.getfilled()
else:
return 0
def getfillable(self):
if self.port:
return self.port.getfillable()
else:
return self.config.getqueuesize()
# private methods
## if 0: access *: private
def ulaw2lin(self, data):
import audioop
return audioop.ulaw2lin(data, 2)
class Play_Audio_sun:
## if 0: access outrate, sampwidth, nchannels, inited_outrate, inited_width, \
## inited_nchannels, converter: private
def __init__(self):
self.outrate = 0
self.sampwidth = 0
self.nchannels = 0
self.inited_outrate = 0
self.inited_width = 0
self.inited_nchannels = 0
self.converter = None
self.port = None
return
def __del__(self):
self.stop()
def setoutrate(self, rate):
self.outrate = rate
self.inited_outrate = 1
def setsampwidth(self, width):
self.sampwidth = width
self.inited_width = 1
def setnchannels(self, nchannels):
self.nchannels = nchannels
self.inited_nchannels = 1
def writeframes(self, data):
if not (self.inited_outrate and self.inited_width and self.inited_nchannels):
raise error, 'params not specified'
if not self.port:
import sunaudiodev, SUNAUDIODEV
self.port = sunaudiodev.open('w')
info = self.port.getinfo()
info.o_sample_rate = self.outrate
info.o_channels = self.nchannels
if self.sampwidth == 0:
info.o_precision = 8
self.o_encoding = SUNAUDIODEV.ENCODING_ULAW
# XXX Hack, hack -- leave defaults
else:
info.o_precision = 8 * self.sampwidth
info.o_encoding = SUNAUDIODEV.ENCODING_LINEAR
self.port.setinfo(info)
if self.converter:
data = self.converter(data)
self.port.write(data)
def wait(self):
if not self.port:
return
self.port.drain()
self.stop()
def stop(self):
if self.port:
self.port.flush()
self.port.close()
self.port = None
def getfilled(self):
if self.port:
return self.port.obufcount()
else:
return 0
## # Nobody remembers what this method does, and it's broken. :-(
## def getfillable(self):
## return BUFFERSIZE - self.getfilled()
def AudioDev():
# Dynamically try to import and use a platform specific module.
try:
import al
except ImportError:
try:
import sunaudiodev
return Play_Audio_sun()
except ImportError:
try:
import Audio_mac
except ImportError:
raise error, 'no audio device'
else:
return Audio_mac.Play_Audio_mac()
else:
return Play_Audio_sgi()
def test(fn = None):
import sys
if sys.argv[1:]:
fn = sys.argv[1]
else:
fn = 'f:just samples:just.aif'
import aifc
af = aifc.open(fn, 'r')
print fn, af.getparams()
p = AudioDev()
p.setoutrate(af.getframerate())
p.setsampwidth(af.getsampwidth())
p.setnchannels(af.getnchannels())
BUFSIZ = af.getframerate()/af.getsampwidth()/af.getnchannels()
while 1:
data = af.readframes(BUFSIZ)
if not data: break
print len(data)
p.writeframes(data)
p.wait()
if __name__ == '__main__':
test()

360
src/python/stdlib/base64.py Executable file
View File

@@ -0,0 +1,360 @@
#! /usr/bin/env python
"""RFC 3548: Base16, Base32, Base64 Data Encodings"""
# Modified 04-Oct-1995 by Jack Jansen to use binascii module
# Modified 30-Dec-2003 by Barry Warsaw to add full RFC 3548 support
import re
import struct
import binascii
__all__ = [
# Legacy interface exports traditional RFC 1521 Base64 encodings
'encode', 'decode', 'encodestring', 'decodestring',
# Generalized interface for other encodings
'b64encode', 'b64decode', 'b32encode', 'b32decode',
'b16encode', 'b16decode',
# Standard Base64 encoding
'standard_b64encode', 'standard_b64decode',
# Some common Base64 alternatives. As referenced by RFC 3458, see thread
# starting at:
#
# http://zgp.org/pipermail/p2p-hackers/2001-September/000316.html
'urlsafe_b64encode', 'urlsafe_b64decode',
]
_translation = [chr(_x) for _x in range(256)]
EMPTYSTRING = ''
def _translate(s, altchars):
translation = _translation[:]
for k, v in altchars.items():
translation[ord(k)] = v
return s.translate(''.join(translation))
# Base64 encoding/decoding uses binascii
def b64encode(s, altchars=None):
"""Encode a string using Base64.
s is the string to encode. Optional altchars must be a string of at least
length 2 (additional characters are ignored) which specifies an
alternative alphabet for the '+' and '/' characters. This allows an
application to e.g. generate url or filesystem safe Base64 strings.
The encoded string is returned.
"""
# Strip off the trailing newline
encoded = binascii.b2a_base64(s)[:-1]
if altchars is not None:
return _translate(encoded, {'+': altchars[0], '/': altchars[1]})
return encoded
def b64decode(s, altchars=None):
"""Decode a Base64 encoded string.
s is the string to decode. Optional altchars must be a string of at least
length 2 (additional characters are ignored) which specifies the
alternative alphabet used instead of the '+' and '/' characters.
The decoded string is returned. A TypeError is raised if s were
incorrectly padded or if there are non-alphabet characters present in the
string.
"""
if altchars is not None:
s = _translate(s, {altchars[0]: '+', altchars[1]: '/'})
try:
return binascii.a2b_base64(s)
except binascii.Error, msg:
# Transform this exception for consistency
raise TypeError(msg)
def standard_b64encode(s):
"""Encode a string using the standard Base64 alphabet.
s is the string to encode. The encoded string is returned.
"""
return b64encode(s)
def standard_b64decode(s):
"""Decode a string encoded with the standard Base64 alphabet.
s is the string to decode. The decoded string is returned. A TypeError
is raised if the string is incorrectly padded or if there are non-alphabet
characters present in the string.
"""
return b64decode(s)
def urlsafe_b64encode(s):
"""Encode a string using a url-safe Base64 alphabet.
s is the string to encode. The encoded string is returned. The alphabet
uses '-' instead of '+' and '_' instead of '/'.
"""
return b64encode(s, '-_')
def urlsafe_b64decode(s):
"""Decode a string encoded with the standard Base64 alphabet.
s is the string to decode. The decoded string is returned. A TypeError
is raised if the string is incorrectly padded or if there are non-alphabet
characters present in the string.
The alphabet uses '-' instead of '+' and '_' instead of '/'.
"""
return b64decode(s, '-_')
# Base32 encoding/decoding must be done in Python
_b32alphabet = {
0: 'A', 9: 'J', 18: 'S', 27: '3',
1: 'B', 10: 'K', 19: 'T', 28: '4',
2: 'C', 11: 'L', 20: 'U', 29: '5',
3: 'D', 12: 'M', 21: 'V', 30: '6',
4: 'E', 13: 'N', 22: 'W', 31: '7',
5: 'F', 14: 'O', 23: 'X',
6: 'G', 15: 'P', 24: 'Y',
7: 'H', 16: 'Q', 25: 'Z',
8: 'I', 17: 'R', 26: '2',
}
_b32tab = _b32alphabet.items()
_b32tab.sort()
_b32tab = [v for k, v in _b32tab]
_b32rev = dict([(v, long(k)) for k, v in _b32alphabet.items()])
def b32encode(s):
"""Encode a string using Base32.
s is the string to encode. The encoded string is returned.
"""
parts = []
quanta, leftover = divmod(len(s), 5)
# Pad the last quantum with zero bits if necessary
if leftover:
s += ('\0' * (5 - leftover))
quanta += 1
for i in range(quanta):
# c1 and c2 are 16 bits wide, c3 is 8 bits wide. The intent of this
# code is to process the 40 bits in units of 5 bits. So we take the 1
# leftover bit of c1 and tack it onto c2. Then we take the 2 leftover
# bits of c2 and tack them onto c3. The shifts and masks are intended
# to give us values of exactly 5 bits in width.
c1, c2, c3 = struct.unpack('!HHB', s[i*5:(i+1)*5])
c2 += (c1 & 1) << 16 # 17 bits wide
c3 += (c2 & 3) << 8 # 10 bits wide
parts.extend([_b32tab[c1 >> 11], # bits 1 - 5
_b32tab[(c1 >> 6) & 0x1f], # bits 6 - 10
_b32tab[(c1 >> 1) & 0x1f], # bits 11 - 15
_b32tab[c2 >> 12], # bits 16 - 20 (1 - 5)
_b32tab[(c2 >> 7) & 0x1f], # bits 21 - 25 (6 - 10)
_b32tab[(c2 >> 2) & 0x1f], # bits 26 - 30 (11 - 15)
_b32tab[c3 >> 5], # bits 31 - 35 (1 - 5)
_b32tab[c3 & 0x1f], # bits 36 - 40 (1 - 5)
])
encoded = EMPTYSTRING.join(parts)
# Adjust for any leftover partial quanta
if leftover == 1:
return encoded[:-6] + '======'
elif leftover == 2:
return encoded[:-4] + '===='
elif leftover == 3:
return encoded[:-3] + '==='
elif leftover == 4:
return encoded[:-1] + '='
return encoded
def b32decode(s, casefold=False, map01=None):
"""Decode a Base32 encoded string.
s is the string to decode. Optional casefold is a flag specifying whether
a lowercase alphabet is acceptable as input. For security purposes, the
default is False.
RFC 3548 allows for optional mapping of the digit 0 (zero) to the letter O
(oh), and for optional mapping of the digit 1 (one) to either the letter I
(eye) or letter L (el). The optional argument map01 when not None,
specifies which letter the digit 1 should be mapped to (when map01 is not
None, the digit 0 is always mapped to the letter O). For security
purposes the default is None, so that 0 and 1 are not allowed in the
input.
The decoded string is returned. A TypeError is raised if s were
incorrectly padded or if there are non-alphabet characters present in the
string.
"""
quanta, leftover = divmod(len(s), 8)
if leftover:
raise TypeError('Incorrect padding')
# Handle section 2.4 zero and one mapping. The flag map01 will be either
# False, or the character to map the digit 1 (one) to. It should be
# either L (el) or I (eye).
if map01:
s = _translate(s, {'0': 'O', '1': map01})
if casefold:
s = s.upper()
# Strip off pad characters from the right. We need to count the pad
# characters because this will tell us how many null bytes to remove from
# the end of the decoded string.
padchars = 0
mo = re.search('(?P<pad>[=]*)$', s)
if mo:
padchars = len(mo.group('pad'))
if padchars > 0:
s = s[:-padchars]
# Now decode the full quanta
parts = []
acc = 0
shift = 35
for c in s:
val = _b32rev.get(c)
if val is None:
raise TypeError('Non-base32 digit found')
acc += _b32rev[c] << shift
shift -= 5
if shift < 0:
parts.append(binascii.unhexlify('%010x' % acc))
acc = 0
shift = 35
# Process the last, partial quanta
last = binascii.unhexlify('%010x' % acc)
if padchars == 0:
last = '' # No characters
elif padchars == 1:
last = last[:-1]
elif padchars == 3:
last = last[:-2]
elif padchars == 4:
last = last[:-3]
elif padchars == 6:
last = last[:-4]
else:
raise TypeError('Incorrect padding')
parts.append(last)
return EMPTYSTRING.join(parts)
# RFC 3548, Base 16 Alphabet specifies uppercase, but hexlify() returns
# lowercase. The RFC also recommends against accepting input case
# insensitively.
def b16encode(s):
"""Encode a string using Base16.
s is the string to encode. The encoded string is returned.
"""
return binascii.hexlify(s).upper()
def b16decode(s, casefold=False):
"""Decode a Base16 encoded string.
s is the string to decode. Optional casefold is a flag specifying whether
a lowercase alphabet is acceptable as input. For security purposes, the
default is False.
The decoded string is returned. A TypeError is raised if s were
incorrectly padded or if there are non-alphabet characters present in the
string.
"""
if casefold:
s = s.upper()
if re.search('[^0-9A-F]', s):
raise TypeError('Non-base16 digit found')
return binascii.unhexlify(s)
# Legacy interface. This code could be cleaned up since I don't believe
# binascii has any line length limitations. It just doesn't seem worth it
# though.
MAXLINESIZE = 76 # Excluding the CRLF
MAXBINSIZE = (MAXLINESIZE//4)*3
def encode(input, output):
"""Encode a file."""
while True:
s = input.read(MAXBINSIZE)
if not s:
break
while len(s) < MAXBINSIZE:
ns = input.read(MAXBINSIZE-len(s))
if not ns:
break
s += ns
line = binascii.b2a_base64(s)
output.write(line)
def decode(input, output):
"""Decode a file."""
while True:
line = input.readline()
if not line:
break
s = binascii.a2b_base64(line)
output.write(s)
def encodestring(s):
"""Encode a string into multiple lines of base-64 data."""
pieces = []
for i in range(0, len(s), MAXBINSIZE):
chunk = s[i : i + MAXBINSIZE]
pieces.append(binascii.b2a_base64(chunk))
return "".join(pieces)
def decodestring(s):
"""Decode a string."""
return binascii.a2b_base64(s)
# Useable as a script...
def test():
"""Small test program"""
import sys, getopt
try:
opts, args = getopt.getopt(sys.argv[1:], 'deut')
except getopt.error, msg:
sys.stdout = sys.stderr
print msg
print """usage: %s [-d|-e|-u|-t] [file|-]
-d, -u: decode
-e: encode (default)
-t: encode and decode string 'Aladdin:open sesame'"""%sys.argv[0]
sys.exit(2)
func = encode
for o, a in opts:
if o == '-e': func = encode
if o == '-d': func = decode
if o == '-u': func = decode
if o == '-t': test1(); return
if args and args[0] != '-':
with open(args[0], 'rb') as f:
func(f, sys.stdout)
else:
func(sys.stdin, sys.stdout)
def test1():
s0 = "Aladdin:open sesame"
s1 = encodestring(s0)
s2 = decodestring(s1)
print s0, repr(s1), s2
if __name__ == '__main__':
test()

628
src/python/stdlib/bdb.py Normal file
View File

@@ -0,0 +1,628 @@
"""Debugger basics"""
import fnmatch
import sys
import os
import types
__all__ = ["BdbQuit","Bdb","Breakpoint"]
class BdbQuit(Exception):
"""Exception to give up completely"""
class Bdb:
"""Generic Python debugger base class.
This class takes care of details of the trace facility;
a derived class should implement user interaction.
The standard debugger class (pdb.Pdb) is an example.
"""
def __init__(self, skip=None):
self.skip = set(skip) if skip else None
self.breaks = {}
self.fncache = {}
def canonic(self, filename):
if filename == "<" + filename[1:-1] + ">":
return filename
canonic = self.fncache.get(filename)
if not canonic:
canonic = os.path.abspath(filename)
canonic = os.path.normcase(canonic)
self.fncache[filename] = canonic
return canonic
def reset(self):
import linecache
linecache.checkcache()
self.botframe = None
self._set_stopinfo(None, None)
def trace_dispatch(self, frame, event, arg):
if self.quitting:
return # None
if event == 'line':
return self.dispatch_line(frame)
if event == 'call':
return self.dispatch_call(frame, arg)
if event == 'return':
return self.dispatch_return(frame, arg)
if event == 'exception':
return self.dispatch_exception(frame, arg)
if event == 'c_call':
return self.trace_dispatch
if event == 'c_exception':
return self.trace_dispatch
if event == 'c_return':
return self.trace_dispatch
print 'bdb.Bdb.dispatch: unknown debugging event:', repr(event)
return self.trace_dispatch
def dispatch_line(self, frame):
if self.stop_here(frame) or self.break_here(frame):
self.user_line(frame)
if self.quitting: raise BdbQuit
return self.trace_dispatch
def dispatch_call(self, frame, arg):
# XXX 'arg' is no longer used
if self.botframe is None:
# First call of dispatch since reset()
self.botframe = frame.f_back # (CT) Note that this may also be None!
return self.trace_dispatch
if not (self.stop_here(frame) or self.break_anywhere(frame)):
# No need to trace this function
return # None
self.user_call(frame, arg)
if self.quitting: raise BdbQuit
return self.trace_dispatch
def dispatch_return(self, frame, arg):
if self.stop_here(frame) or frame == self.returnframe:
self.user_return(frame, arg)
if self.quitting: raise BdbQuit
return self.trace_dispatch
def dispatch_exception(self, frame, arg):
if self.stop_here(frame):
self.user_exception(frame, arg)
if self.quitting: raise BdbQuit
return self.trace_dispatch
# Normally derived classes don't override the following
# methods, but they may if they want to redefine the
# definition of stopping and breakpoints.
def is_skipped_module(self, module_name):
for pattern in self.skip:
if fnmatch.fnmatch(module_name, pattern):
return True
return False
def stop_here(self, frame):
# (CT) stopframe may now also be None, see dispatch_call.
# (CT) the former test for None is therefore removed from here.
if self.skip and \
self.is_skipped_module(frame.f_globals.get('__name__')):
return False
if frame is self.stopframe:
if self.stoplineno == -1:
return False
return frame.f_lineno >= self.stoplineno
while frame is not None and frame is not self.stopframe:
if frame is self.botframe:
return True
frame = frame.f_back
return False
def break_here(self, frame):
filename = self.canonic(frame.f_code.co_filename)
if not filename in self.breaks:
return False
lineno = frame.f_lineno
if not lineno in self.breaks[filename]:
# The line itself has no breakpoint, but maybe the line is the
# first line of a function with breakpoint set by function name.
lineno = frame.f_code.co_firstlineno
if not lineno in self.breaks[filename]:
return False
# flag says ok to delete temp. bp
(bp, flag) = effective(filename, lineno, frame)
if bp:
self.currentbp = bp.number
if (flag and bp.temporary):
self.do_clear(str(bp.number))
return True
else:
return False
def do_clear(self, arg):
raise NotImplementedError, "subclass of bdb must implement do_clear()"
def break_anywhere(self, frame):
return self.canonic(frame.f_code.co_filename) in self.breaks
# Derived classes should override the user_* methods
# to gain control.
def user_call(self, frame, argument_list):
"""This method is called when there is the remote possibility
that we ever need to stop in this function."""
pass
def user_line(self, frame):
"""This method is called when we stop or break at this line."""
pass
def user_return(self, frame, return_value):
"""This method is called when a return trap is set here."""
pass
def user_exception(self, frame, exc_info):
exc_type, exc_value, exc_traceback = exc_info
"""This method is called if an exception occurs,
but only if we are to stop at or just below this level."""
pass
def _set_stopinfo(self, stopframe, returnframe, stoplineno=0):
self.stopframe = stopframe
self.returnframe = returnframe
self.quitting = 0
# stoplineno >= 0 means: stop at line >= the stoplineno
# stoplineno -1 means: don't stop at all
self.stoplineno = stoplineno
# Derived classes and clients can call the following methods
# to affect the stepping state.
def set_until(self, frame): #the name "until" is borrowed from gdb
"""Stop when the line with the line no greater than the current one is
reached or when returning from current frame"""
self._set_stopinfo(frame, frame, frame.f_lineno+1)
def set_step(self):
"""Stop after one line of code."""
self._set_stopinfo(None, None)
def set_next(self, frame):
"""Stop on the next line in or below the given frame."""
self._set_stopinfo(frame, None)
def set_return(self, frame):
"""Stop when returning from the given frame."""
self._set_stopinfo(frame.f_back, frame)
def set_trace(self, frame=None):
"""Start debugging from `frame`.
If frame is not specified, debugging starts from caller's frame.
"""
if frame is None:
frame = sys._getframe().f_back
self.reset()
while frame:
frame.f_trace = self.trace_dispatch
self.botframe = frame
frame = frame.f_back
self.set_step()
sys.settrace(self.trace_dispatch)
def set_continue(self):
# Don't stop except at breakpoints or when finished
self._set_stopinfo(self.botframe, None, -1)
if not self.breaks:
# no breakpoints; run without debugger overhead
sys.settrace(None)
frame = sys._getframe().f_back
while frame and frame is not self.botframe:
del frame.f_trace
frame = frame.f_back
def set_quit(self):
self.stopframe = self.botframe
self.returnframe = None
self.quitting = 1
sys.settrace(None)
# Derived classes and clients can call the following methods
# to manipulate breakpoints. These methods return an
# error message is something went wrong, None if all is well.
# Set_break prints out the breakpoint line and file:lineno.
# Call self.get_*break*() to see the breakpoints or better
# for bp in Breakpoint.bpbynumber: if bp: bp.bpprint().
def set_break(self, filename, lineno, temporary=0, cond = None,
funcname=None):
filename = self.canonic(filename)
import linecache # Import as late as possible
line = linecache.getline(filename, lineno)
if not line:
return 'Line %s:%d does not exist' % (filename,
lineno)
if not filename in self.breaks:
self.breaks[filename] = []
list = self.breaks[filename]
if not lineno in list:
list.append(lineno)
bp = Breakpoint(filename, lineno, temporary, cond, funcname)
def clear_break(self, filename, lineno):
filename = self.canonic(filename)
if not filename in self.breaks:
return 'There are no breakpoints in %s' % filename
if lineno not in self.breaks[filename]:
return 'There is no breakpoint at %s:%d' % (filename,
lineno)
# If there's only one bp in the list for that file,line
# pair, then remove the breaks entry
for bp in Breakpoint.bplist[filename, lineno][:]:
bp.deleteMe()
if (filename, lineno) not in Breakpoint.bplist:
self.breaks[filename].remove(lineno)
if not self.breaks[filename]:
del self.breaks[filename]
def clear_bpbynumber(self, arg):
try:
number = int(arg)
except:
return 'Non-numeric breakpoint number (%s)' % arg
try:
bp = Breakpoint.bpbynumber[number]
except IndexError:
return 'Breakpoint number (%d) out of range' % number
if not bp:
return 'Breakpoint (%d) already deleted' % number
self.clear_break(bp.file, bp.line)
def clear_all_file_breaks(self, filename):
filename = self.canonic(filename)
if not filename in self.breaks:
return 'There are no breakpoints in %s' % filename
for line in self.breaks[filename]:
blist = Breakpoint.bplist[filename, line]
for bp in blist:
bp.deleteMe()
del self.breaks[filename]
def clear_all_breaks(self):
if not self.breaks:
return 'There are no breakpoints'
for bp in Breakpoint.bpbynumber:
if bp:
bp.deleteMe()
self.breaks = {}
def get_break(self, filename, lineno):
filename = self.canonic(filename)
return filename in self.breaks and \
lineno in self.breaks[filename]
def get_breaks(self, filename, lineno):
filename = self.canonic(filename)
return filename in self.breaks and \
lineno in self.breaks[filename] and \
Breakpoint.bplist[filename, lineno] or []
def get_file_breaks(self, filename):
filename = self.canonic(filename)
if filename in self.breaks:
return self.breaks[filename]
else:
return []
def get_all_breaks(self):
return self.breaks
# Derived classes and clients can call the following method
# to get a data structure representing a stack trace.
def get_stack(self, f, t):
stack = []
if t and t.tb_frame is f:
t = t.tb_next
while f is not None:
stack.append((f, f.f_lineno))
if f is self.botframe:
break
f = f.f_back
stack.reverse()
i = max(0, len(stack) - 1)
while t is not None:
stack.append((t.tb_frame, t.tb_lineno))
t = t.tb_next
if f is None:
i = max(0, len(stack) - 1)
return stack, i
#
def format_stack_entry(self, frame_lineno, lprefix=': '):
import linecache, repr
frame, lineno = frame_lineno
filename = self.canonic(frame.f_code.co_filename)
s = '%s(%r)' % (filename, lineno)
if frame.f_code.co_name:
s = s + frame.f_code.co_name
else:
s = s + "<lambda>"
if '__args__' in frame.f_locals:
args = frame.f_locals['__args__']
else:
args = None
if args:
s = s + repr.repr(args)
else:
s = s + '()'
if '__return__' in frame.f_locals:
rv = frame.f_locals['__return__']
s = s + '->'
s = s + repr.repr(rv)
line = linecache.getline(filename, lineno, frame.f_globals)
if line: s = s + lprefix + line.strip()
return s
# The following two methods can be called by clients to use
# a debugger to debug a statement, given as a string.
def run(self, cmd, globals=None, locals=None):
if globals is None:
import __main__
globals = __main__.__dict__
if locals is None:
locals = globals
self.reset()
sys.settrace(self.trace_dispatch)
if not isinstance(cmd, types.CodeType):
cmd = cmd+'\n'
try:
exec cmd in globals, locals
except BdbQuit:
pass
finally:
self.quitting = 1
sys.settrace(None)
def runeval(self, expr, globals=None, locals=None):
if globals is None:
import __main__
globals = __main__.__dict__
if locals is None:
locals = globals
self.reset()
sys.settrace(self.trace_dispatch)
if not isinstance(expr, types.CodeType):
expr = expr+'\n'
try:
return eval(expr, globals, locals)
except BdbQuit:
pass
finally:
self.quitting = 1
sys.settrace(None)
def runctx(self, cmd, globals, locals):
# B/W compatibility
self.run(cmd, globals, locals)
# This method is more useful to debug a single function call.
def runcall(self, func, *args, **kwds):
self.reset()
sys.settrace(self.trace_dispatch)
res = None
try:
res = func(*args, **kwds)
except BdbQuit:
pass
finally:
self.quitting = 1
sys.settrace(None)
return res
def set_trace():
Bdb().set_trace()
class Breakpoint:
"""Breakpoint class
Implements temporary breakpoints, ignore counts, disabling and
(re)-enabling, and conditionals.
Breakpoints are indexed by number through bpbynumber and by
the file,line tuple using bplist. The former points to a
single instance of class Breakpoint. The latter points to a
list of such instances since there may be more than one
breakpoint per line.
"""
# XXX Keeping state in the class is a mistake -- this means
# you cannot have more than one active Bdb instance.
next = 1 # Next bp to be assigned
bplist = {} # indexed by (file, lineno) tuple
bpbynumber = [None] # Each entry is None or an instance of Bpt
# index 0 is unused, except for marking an
# effective break .... see effective()
def __init__(self, file, line, temporary=0, cond=None, funcname=None):
self.funcname = funcname
# Needed if funcname is not None.
self.func_first_executable_line = None
self.file = file # This better be in canonical form!
self.line = line
self.temporary = temporary
self.cond = cond
self.enabled = 1
self.ignore = 0
self.hits = 0
self.number = Breakpoint.next
Breakpoint.next = Breakpoint.next + 1
# Build the two lists
self.bpbynumber.append(self)
if (file, line) in self.bplist:
self.bplist[file, line].append(self)
else:
self.bplist[file, line] = [self]
def deleteMe(self):
index = (self.file, self.line)
self.bpbynumber[self.number] = None # No longer in list
self.bplist[index].remove(self)
if not self.bplist[index]:
# No more bp for this f:l combo
del self.bplist[index]
def enable(self):
self.enabled = 1
def disable(self):
self.enabled = 0
def bpprint(self, out=None):
if out is None:
out = sys.stdout
if self.temporary:
disp = 'del '
else:
disp = 'keep '
if self.enabled:
disp = disp + 'yes '
else:
disp = disp + 'no '
print >>out, '%-4dbreakpoint %s at %s:%d' % (self.number, disp,
self.file, self.line)
if self.cond:
print >>out, '\tstop only if %s' % (self.cond,)
if self.ignore:
print >>out, '\tignore next %d hits' % (self.ignore)
if (self.hits):
if (self.hits > 1): ss = 's'
else: ss = ''
print >>out, ('\tbreakpoint already hit %d time%s' %
(self.hits, ss))
# -----------end of Breakpoint class----------
def checkfuncname(b, frame):
"""Check whether we should break here because of `b.funcname`."""
if not b.funcname:
# Breakpoint was set via line number.
if b.line != frame.f_lineno:
# Breakpoint was set at a line with a def statement and the function
# defined is called: don't break.
return False
return True
# Breakpoint set via function name.
if frame.f_code.co_name != b.funcname:
# It's not a function call, but rather execution of def statement.
return False
# We are in the right frame.
if not b.func_first_executable_line:
# The function is entered for the 1st time.
b.func_first_executable_line = frame.f_lineno
if b.func_first_executable_line != frame.f_lineno:
# But we are not at the first line number: don't break.
return False
return True
# Determines if there is an effective (active) breakpoint at this
# line of code. Returns breakpoint number or 0 if none
def effective(file, line, frame):
"""Determine which breakpoint for this file:line is to be acted upon.
Called only if we know there is a bpt at this
location. Returns breakpoint that was triggered and a flag
that indicates if it is ok to delete a temporary bp.
"""
possibles = Breakpoint.bplist[file,line]
for i in range(0, len(possibles)):
b = possibles[i]
if b.enabled == 0:
continue
if not checkfuncname(b, frame):
continue
# Count every hit when bp is enabled
b.hits = b.hits + 1
if not b.cond:
# If unconditional, and ignoring,
# go on to next, else break
if b.ignore > 0:
b.ignore = b.ignore -1
continue
else:
# breakpoint and marker that's ok
# to delete if temporary
return (b,1)
else:
# Conditional bp.
# Ignore count applies only to those bpt hits where the
# condition evaluates to true.
try:
val = eval(b.cond, frame.f_globals,
frame.f_locals)
if val:
if b.ignore > 0:
b.ignore = b.ignore -1
# continue
else:
return (b,1)
# else:
# continue
except:
# if eval fails, most conservative
# thing is to stop on breakpoint
# regardless of ignore count.
# Don't delete temporary,
# as another hint to user.
return (b,0)
return (None, None)
# -------------------- testing --------------------
class Tdb(Bdb):
def user_call(self, frame, args):
name = frame.f_code.co_name
if not name: name = '???'
print '+++ call', name, args
def user_line(self, frame):
import linecache
name = frame.f_code.co_name
if not name: name = '???'
fn = self.canonic(frame.f_code.co_filename)
line = linecache.getline(fn, frame.f_lineno, frame.f_globals)
print '+++', fn, frame.f_lineno, name, ':', line.strip()
def user_return(self, frame, retval):
print '+++ return', retval
def user_exception(self, frame, exc_stuff):
print '+++ exception', exc_stuff
self.set_continue()
def foo(n):
print 'foo(', n, ')'
x = bar(n*10)
print 'bar returned', x
def bar(a):
print 'bar(', a, ')'
return a/2
def test():
t = Tdb()
t.run('import bdb; bdb.foo(10)')
# end

508
src/python/stdlib/binhex.py Normal file
View File

@@ -0,0 +1,508 @@
"""Macintosh binhex compression/decompression.
easy interface:
binhex(inputfilename, outputfilename)
hexbin(inputfilename, outputfilename)
"""
#
# Jack Jansen, CWI, August 1995.
#
# The module is supposed to be as compatible as possible. Especially the
# easy interface should work "as expected" on any platform.
# XXXX Note: currently, textfiles appear in mac-form on all platforms.
# We seem to lack a simple character-translate in python.
# (we should probably use ISO-Latin-1 on all but the mac platform).
# XXXX The simple routines are too simple: they expect to hold the complete
# files in-core. Should be fixed.
# XXXX It would be nice to handle AppleDouble format on unix
# (for servers serving macs).
# XXXX I don't understand what happens when you get 0x90 times the same byte on
# input. The resulting code (xx 90 90) would appear to be interpreted as an
# escaped *value* of 0x90. All coders I've seen appear to ignore this nicety...
#
import sys
import os
import struct
import binascii
__all__ = ["binhex","hexbin","Error"]
class Error(Exception):
pass
# States (what have we written)
[_DID_HEADER, _DID_DATA, _DID_RSRC] = range(3)
# Various constants
REASONABLY_LARGE=32768 # Minimal amount we pass the rle-coder
LINELEN=64
RUNCHAR=chr(0x90) # run-length introducer
#
# This code is no longer byte-order dependent
#
# Workarounds for non-mac machines.
try:
from Carbon.File import FSSpec, FInfo
from MacOS import openrf
def getfileinfo(name):
finfo = FSSpec(name).FSpGetFInfo()
dir, file = os.path.split(name)
# XXX Get resource/data sizes
fp = open(name, 'rb')
fp.seek(0, 2)
dlen = fp.tell()
fp = openrf(name, '*rb')
fp.seek(0, 2)
rlen = fp.tell()
return file, finfo, dlen, rlen
def openrsrc(name, *mode):
if not mode:
mode = '*rb'
else:
mode = '*' + mode[0]
return openrf(name, mode)
except ImportError:
#
# Glue code for non-macintosh usage
#
class FInfo:
def __init__(self):
self.Type = '????'
self.Creator = '????'
self.Flags = 0
def getfileinfo(name):
finfo = FInfo()
# Quick check for textfile
fp = open(name)
data = open(name).read(256)
for c in data:
if not c.isspace() and (c<' ' or ord(c) > 0x7f):
break
else:
finfo.Type = 'TEXT'
fp.seek(0, 2)
dsize = fp.tell()
fp.close()
dir, file = os.path.split(name)
file = file.replace(':', '-', 1)
return file, finfo, dsize, 0
class openrsrc:
def __init__(self, *args):
pass
def read(self, *args):
return ''
def write(self, *args):
pass
def close(self):
pass
class _Hqxcoderengine:
"""Write data to the coder in 3-byte chunks"""
def __init__(self, ofp):
self.ofp = ofp
self.data = ''
self.hqxdata = ''
self.linelen = LINELEN-1
def write(self, data):
self.data = self.data + data
datalen = len(self.data)
todo = (datalen//3)*3
data = self.data[:todo]
self.data = self.data[todo:]
if not data:
return
self.hqxdata = self.hqxdata + binascii.b2a_hqx(data)
self._flush(0)
def _flush(self, force):
first = 0
while first <= len(self.hqxdata)-self.linelen:
last = first + self.linelen
self.ofp.write(self.hqxdata[first:last]+'\n')
self.linelen = LINELEN
first = last
self.hqxdata = self.hqxdata[first:]
if force:
self.ofp.write(self.hqxdata + ':\n')
def close(self):
if self.data:
self.hqxdata = \
self.hqxdata + binascii.b2a_hqx(self.data)
self._flush(1)
self.ofp.close()
del self.ofp
class _Rlecoderengine:
"""Write data to the RLE-coder in suitably large chunks"""
def __init__(self, ofp):
self.ofp = ofp
self.data = ''
def write(self, data):
self.data = self.data + data
if len(self.data) < REASONABLY_LARGE:
return
rledata = binascii.rlecode_hqx(self.data)
self.ofp.write(rledata)
self.data = ''
def close(self):
if self.data:
rledata = binascii.rlecode_hqx(self.data)
self.ofp.write(rledata)
self.ofp.close()
del self.ofp
class BinHex:
def __init__(self, name_finfo_dlen_rlen, ofp):
name, finfo, dlen, rlen = name_finfo_dlen_rlen
if type(ofp) == type(''):
ofname = ofp
ofp = open(ofname, 'w')
ofp.write('(This file must be converted with BinHex 4.0)\n\n:')
hqxer = _Hqxcoderengine(ofp)
self.ofp = _Rlecoderengine(hqxer)
self.crc = 0
if finfo is None:
finfo = FInfo()
self.dlen = dlen
self.rlen = rlen
self._writeinfo(name, finfo)
self.state = _DID_HEADER
def _writeinfo(self, name, finfo):
nl = len(name)
if nl > 63:
raise Error, 'Filename too long'
d = chr(nl) + name + '\0'
d2 = finfo.Type + finfo.Creator
# Force all structs to be packed with big-endian
d3 = struct.pack('>h', finfo.Flags)
d4 = struct.pack('>ii', self.dlen, self.rlen)
info = d + d2 + d3 + d4
self._write(info)
self._writecrc()
def _write(self, data):
self.crc = binascii.crc_hqx(data, self.crc)
self.ofp.write(data)
def _writecrc(self):
# XXXX Should this be here??
# self.crc = binascii.crc_hqx('\0\0', self.crc)
if self.crc < 0:
fmt = '>h'
else:
fmt = '>H'
self.ofp.write(struct.pack(fmt, self.crc))
self.crc = 0
def write(self, data):
if self.state != _DID_HEADER:
raise Error, 'Writing data at the wrong time'
self.dlen = self.dlen - len(data)
self._write(data)
def close_data(self):
if self.dlen != 0:
raise Error, 'Incorrect data size, diff=%r' % (self.rlen,)
self._writecrc()
self.state = _DID_DATA
def write_rsrc(self, data):
if self.state < _DID_DATA:
self.close_data()
if self.state != _DID_DATA:
raise Error, 'Writing resource data at the wrong time'
self.rlen = self.rlen - len(data)
self._write(data)
def close(self):
if self.state < _DID_DATA:
self.close_data()
if self.state != _DID_DATA:
raise Error, 'Close at the wrong time'
if self.rlen != 0:
raise Error, \
"Incorrect resource-datasize, diff=%r" % (self.rlen,)
self._writecrc()
self.ofp.close()
self.state = None
del self.ofp
def binhex(inp, out):
"""(infilename, outfilename) - Create binhex-encoded copy of a file"""
finfo = getfileinfo(inp)
ofp = BinHex(finfo, out)
ifp = open(inp, 'rb')
# XXXX Do textfile translation on non-mac systems
while 1:
d = ifp.read(128000)
if not d: break
ofp.write(d)
ofp.close_data()
ifp.close()
ifp = openrsrc(inp, 'rb')
while 1:
d = ifp.read(128000)
if not d: break
ofp.write_rsrc(d)
ofp.close()
ifp.close()
class _Hqxdecoderengine:
"""Read data via the decoder in 4-byte chunks"""
def __init__(self, ifp):
self.ifp = ifp
self.eof = 0
def read(self, totalwtd):
"""Read at least wtd bytes (or until EOF)"""
decdata = ''
wtd = totalwtd
#
# The loop here is convoluted, since we don't really now how
# much to decode: there may be newlines in the incoming data.
while wtd > 0:
if self.eof: return decdata
wtd = ((wtd+2)//3)*4
data = self.ifp.read(wtd)
#
# Next problem: there may not be a complete number of
# bytes in what we pass to a2b. Solve by yet another
# loop.
#
while 1:
try:
decdatacur, self.eof = \
binascii.a2b_hqx(data)
break
except binascii.Incomplete:
pass
newdata = self.ifp.read(1)
if not newdata:
raise Error, \
'Premature EOF on binhex file'
data = data + newdata
decdata = decdata + decdatacur
wtd = totalwtd - len(decdata)
if not decdata and not self.eof:
raise Error, 'Premature EOF on binhex file'
return decdata
def close(self):
self.ifp.close()
class _Rledecoderengine:
"""Read data via the RLE-coder"""
def __init__(self, ifp):
self.ifp = ifp
self.pre_buffer = ''
self.post_buffer = ''
self.eof = 0
def read(self, wtd):
if wtd > len(self.post_buffer):
self._fill(wtd-len(self.post_buffer))
rv = self.post_buffer[:wtd]
self.post_buffer = self.post_buffer[wtd:]
return rv
def _fill(self, wtd):
self.pre_buffer = self.pre_buffer + self.ifp.read(wtd+4)
if self.ifp.eof:
self.post_buffer = self.post_buffer + \
binascii.rledecode_hqx(self.pre_buffer)
self.pre_buffer = ''
return
#
# Obfuscated code ahead. We have to take care that we don't
# end up with an orphaned RUNCHAR later on. So, we keep a couple
# of bytes in the buffer, depending on what the end of
# the buffer looks like:
# '\220\0\220' - Keep 3 bytes: repeated \220 (escaped as \220\0)
# '?\220' - Keep 2 bytes: repeated something-else
# '\220\0' - Escaped \220: Keep 2 bytes.
# '?\220?' - Complete repeat sequence: decode all
# otherwise: keep 1 byte.
#
mark = len(self.pre_buffer)
if self.pre_buffer[-3:] == RUNCHAR + '\0' + RUNCHAR:
mark = mark - 3
elif self.pre_buffer[-1] == RUNCHAR:
mark = mark - 2
elif self.pre_buffer[-2:] == RUNCHAR + '\0':
mark = mark - 2
elif self.pre_buffer[-2] == RUNCHAR:
pass # Decode all
else:
mark = mark - 1
self.post_buffer = self.post_buffer + \
binascii.rledecode_hqx(self.pre_buffer[:mark])
self.pre_buffer = self.pre_buffer[mark:]
def close(self):
self.ifp.close()
class HexBin:
def __init__(self, ifp):
if type(ifp) == type(''):
ifp = open(ifp)
#
# Find initial colon.
#
while 1:
ch = ifp.read(1)
if not ch:
raise Error, "No binhex data found"
# Cater for \r\n terminated lines (which show up as \n\r, hence
# all lines start with \r)
if ch == '\r':
continue
if ch == ':':
break
if ch != '\n':
dummy = ifp.readline()
hqxifp = _Hqxdecoderengine(ifp)
self.ifp = _Rledecoderengine(hqxifp)
self.crc = 0
self._readheader()
def _read(self, len):
data = self.ifp.read(len)
self.crc = binascii.crc_hqx(data, self.crc)
return data
def _checkcrc(self):
filecrc = struct.unpack('>h', self.ifp.read(2))[0] & 0xffff
#self.crc = binascii.crc_hqx('\0\0', self.crc)
# XXXX Is this needed??
self.crc = self.crc & 0xffff
if filecrc != self.crc:
raise Error, 'CRC error, computed %x, read %x' \
%(self.crc, filecrc)
self.crc = 0
def _readheader(self):
len = self._read(1)
fname = self._read(ord(len))
rest = self._read(1+4+4+2+4+4)
self._checkcrc()
type = rest[1:5]
creator = rest[5:9]
flags = struct.unpack('>h', rest[9:11])[0]
self.dlen = struct.unpack('>l', rest[11:15])[0]
self.rlen = struct.unpack('>l', rest[15:19])[0]
self.FName = fname
self.FInfo = FInfo()
self.FInfo.Creator = creator
self.FInfo.Type = type
self.FInfo.Flags = flags
self.state = _DID_HEADER
def read(self, *n):
if self.state != _DID_HEADER:
raise Error, 'Read data at wrong time'
if n:
n = n[0]
n = min(n, self.dlen)
else:
n = self.dlen
rv = ''
while len(rv) < n:
rv = rv + self._read(n-len(rv))
self.dlen = self.dlen - n
return rv
def close_data(self):
if self.state != _DID_HEADER:
raise Error, 'close_data at wrong time'
if self.dlen:
dummy = self._read(self.dlen)
self._checkcrc()
self.state = _DID_DATA
def read_rsrc(self, *n):
if self.state == _DID_HEADER:
self.close_data()
if self.state != _DID_DATA:
raise Error, 'Read resource data at wrong time'
if n:
n = n[0]
n = min(n, self.rlen)
else:
n = self.rlen
self.rlen = self.rlen - n
return self._read(n)
def close(self):
if self.rlen:
dummy = self.read_rsrc(self.rlen)
self._checkcrc()
self.state = _DID_RSRC
self.ifp.close()
def hexbin(inp, out):
"""(infilename, outfilename) - Decode binhexed file"""
ifp = HexBin(inp)
finfo = ifp.FInfo
if not out:
out = ifp.FName
ofp = open(out, 'wb')
# XXXX Do translation on non-mac systems
while 1:
d = ifp.read(128000)
if not d: break
ofp.write(d)
ofp.close()
ifp.close_data()
d = ifp.read_rsrc(128000)
if d:
ofp = openrsrc(out, 'wb')
ofp.write(d)
while 1:
d = ifp.read_rsrc(128000)
if not d: break
ofp.write(d)
ofp.close()
ifp.close()
def _test():
fname = sys.argv[1]
binhex(fname, fname+'.hqx')
hexbin(fname+'.hqx', fname+'.viahqx')
#hexbin(fname, fname+'.unpacked')
sys.exit(1)
if __name__ == '__main__':
_test()

View File

@@ -0,0 +1,92 @@
"""Bisection algorithms."""
def insort_right(a, x, lo=0, hi=None):
"""Insert item x in list a, and keep it sorted assuming a is sorted.
If x is already in a, insert it to the right of the rightmost x.
Optional args lo (default 0) and hi (default len(a)) bound the
slice of a to be searched.
"""
if lo < 0:
raise ValueError('lo must be non-negative')
if hi is None:
hi = len(a)
while lo < hi:
mid = (lo+hi)//2
if x < a[mid]: hi = mid
else: lo = mid+1
a.insert(lo, x)
insort = insort_right # backward compatibility
def bisect_right(a, x, lo=0, hi=None):
"""Return the index where to insert item x in list a, assuming a is sorted.
The return value i is such that all e in a[:i] have e <= x, and all e in
a[i:] have e > x. So if x already appears in the list, a.insert(x) will
insert just after the rightmost x already there.
Optional args lo (default 0) and hi (default len(a)) bound the
slice of a to be searched.
"""
if lo < 0:
raise ValueError('lo must be non-negative')
if hi is None:
hi = len(a)
while lo < hi:
mid = (lo+hi)//2
if x < a[mid]: hi = mid
else: lo = mid+1
return lo
bisect = bisect_right # backward compatibility
def insort_left(a, x, lo=0, hi=None):
"""Insert item x in list a, and keep it sorted assuming a is sorted.
If x is already in a, insert it to the left of the leftmost x.
Optional args lo (default 0) and hi (default len(a)) bound the
slice of a to be searched.
"""
if lo < 0:
raise ValueError('lo must be non-negative')
if hi is None:
hi = len(a)
while lo < hi:
mid = (lo+hi)//2
if a[mid] < x: lo = mid+1
else: hi = mid
a.insert(lo, x)
def bisect_left(a, x, lo=0, hi=None):
"""Return the index where to insert item x in list a, assuming a is sorted.
The return value i is such that all e in a[:i] have e < x, and all e in
a[i:] have e >= x. So if x already appears in the list, a.insert(x) will
insert just before the leftmost x already there.
Optional args lo (default 0) and hi (default len(a)) bound the
slice of a to be searched.
"""
if lo < 0:
raise ValueError('lo must be non-negative')
if hi is None:
hi = len(a)
while lo < hi:
mid = (lo+hi)//2
if a[mid] < x: lo = mid+1
else: hi = mid
return lo
# Overwrite above definitions with a fast C implementation
try:
from _bisect import *
except ImportError:
pass

199
src/python/stdlib/cProfile.py Executable file
View File

@@ -0,0 +1,199 @@
#! /usr/bin/env python
"""Python interface for the 'lsprof' profiler.
Compatible with the 'profile' module.
"""
__all__ = ["run", "runctx", "help", "Profile"]
import _lsprof
# ____________________________________________________________
# Simple interface
def run(statement, filename=None, sort=-1):
"""Run statement under profiler optionally saving results in filename
This function takes a single argument that can be passed to the
"exec" statement, and an optional file name. In all cases this
routine attempts to "exec" its first argument and gather profiling
statistics from the execution. If no file name is present, then this
function automatically prints a simple profiling report, sorted by the
standard name string (file/line/function-name) that is presented in
each line.
"""
prof = Profile()
result = None
try:
try:
prof = prof.run(statement)
except SystemExit:
pass
finally:
if filename is not None:
prof.dump_stats(filename)
else:
result = prof.print_stats(sort)
return result
def runctx(statement, globals, locals, filename=None, sort=-1):
"""Run statement under profiler, supplying your own globals and locals,
optionally saving results in filename.
statement and filename have the same semantics as profile.run
"""
prof = Profile()
result = None
try:
try:
prof = prof.runctx(statement, globals, locals)
except SystemExit:
pass
finally:
if filename is not None:
prof.dump_stats(filename)
else:
result = prof.print_stats(sort)
return result
# Backwards compatibility.
def help():
print "Documentation for the profile/cProfile modules can be found "
print "in the Python Library Reference, section 'The Python Profiler'."
# ____________________________________________________________
class Profile(_lsprof.Profiler):
"""Profile(custom_timer=None, time_unit=None, subcalls=True, builtins=True)
Builds a profiler object using the specified timer function.
The default timer is a fast built-in one based on real time.
For custom timer functions returning integers, time_unit can
be a float specifying a scale (i.e. how long each integer unit
is, in seconds).
"""
# Most of the functionality is in the base class.
# This subclass only adds convenient and backward-compatible methods.
def print_stats(self, sort=-1):
import pstats
pstats.Stats(self).strip_dirs().sort_stats(sort).print_stats()
def dump_stats(self, file):
import marshal
f = open(file, 'wb')
self.create_stats()
marshal.dump(self.stats, f)
f.close()
def create_stats(self):
self.disable()
self.snapshot_stats()
def snapshot_stats(self):
entries = self.getstats()
self.stats = {}
callersdicts = {}
# call information
for entry in entries:
func = label(entry.code)
nc = entry.callcount # ncalls column of pstats (before '/')
cc = nc - entry.reccallcount # ncalls column of pstats (after '/')
tt = entry.inlinetime # tottime column of pstats
ct = entry.totaltime # cumtime column of pstats
callers = {}
callersdicts[id(entry.code)] = callers
self.stats[func] = cc, nc, tt, ct, callers
# subcall information
for entry in entries:
if entry.calls:
func = label(entry.code)
for subentry in entry.calls:
try:
callers = callersdicts[id(subentry.code)]
except KeyError:
continue
nc = subentry.callcount
cc = nc - subentry.reccallcount
tt = subentry.inlinetime
ct = subentry.totaltime
if func in callers:
prev = callers[func]
nc += prev[0]
cc += prev[1]
tt += prev[2]
ct += prev[3]
callers[func] = nc, cc, tt, ct
# The following two methods can be called by clients to use
# a profiler to profile a statement, given as a string.
def run(self, cmd):
import __main__
dict = __main__.__dict__
return self.runctx(cmd, dict, dict)
def runctx(self, cmd, globals, locals):
self.enable()
try:
exec cmd in globals, locals
finally:
self.disable()
return self
# This method is more useful to profile a single function call.
def runcall(self, func, *args, **kw):
self.enable()
try:
return func(*args, **kw)
finally:
self.disable()
# ____________________________________________________________
def label(code):
if isinstance(code, str):
return ('~', 0, code) # built-in functions ('~' sorts at the end)
else:
return (code.co_filename, code.co_firstlineno, code.co_name)
# ____________________________________________________________
def main():
import os, sys
from optparse import OptionParser
usage = "cProfile.py [-o output_file_path] [-s sort] scriptfile [arg] ..."
parser = OptionParser(usage=usage)
parser.allow_interspersed_args = False
parser.add_option('-o', '--outfile', dest="outfile",
help="Save stats to <outfile>", default=None)
parser.add_option('-s', '--sort', dest="sort",
help="Sort order when printing to stdout, based on pstats.Stats class",
default=-1)
if not sys.argv[1:]:
parser.print_usage()
sys.exit(2)
(options, args) = parser.parse_args()
sys.argv[:] = args
if len(args) > 0:
progname = args[0]
sys.path.insert(0, os.path.dirname(progname))
with open(progname, 'rb') as fp:
code = compile(fp.read(), progname, 'exec')
globs = {
'__file__': progname,
'__name__': '__main__',
'__package__': None,
}
runctx(code, globs, None, options.outfile, options.sort)
else:
parser.print_usage()
return parser
# When invoked as main program, invoke the profiler on a script
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,708 @@
"""Calendar printing functions
Note when comparing these calendars to the ones printed by cal(1): By
default, these calendars have Monday as the first day of the week, and
Sunday as the last (the European convention). Use setfirstweekday() to
set the first day of the week (0=Monday, 6=Sunday)."""
import sys
import datetime
import locale as _locale
__all__ = ["IllegalMonthError", "IllegalWeekdayError", "setfirstweekday",
"firstweekday", "isleap", "leapdays", "weekday", "monthrange",
"monthcalendar", "prmonth", "month", "prcal", "calendar",
"timegm", "month_name", "month_abbr", "day_name", "day_abbr"]
# Exception raised for bad input (with string parameter for details)
error = ValueError
# Exceptions raised for bad input
class IllegalMonthError(ValueError):
def __init__(self, month):
self.month = month
def __str__(self):
return "bad month number %r; must be 1-12" % self.month
class IllegalWeekdayError(ValueError):
def __init__(self, weekday):
self.weekday = weekday
def __str__(self):
return "bad weekday number %r; must be 0 (Monday) to 6 (Sunday)" % self.weekday
# Constants for months referenced later
January = 1
February = 2
# Number of days per month (except for February in leap years)
mdays = [0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
# This module used to have hard-coded lists of day and month names, as
# English strings. The classes following emulate a read-only version of
# that, but supply localized names. Note that the values are computed
# fresh on each call, in case the user changes locale between calls.
class _localized_month:
_months = [datetime.date(2001, i+1, 1).strftime for i in range(12)]
_months.insert(0, lambda x: "")
def __init__(self, format):
self.format = format
def __getitem__(self, i):
funcs = self._months[i]
if isinstance(i, slice):
return [f(self.format) for f in funcs]
else:
return funcs(self.format)
def __len__(self):
return 13
class _localized_day:
# January 1, 2001, was a Monday.
_days = [datetime.date(2001, 1, i+1).strftime for i in range(7)]
def __init__(self, format):
self.format = format
def __getitem__(self, i):
funcs = self._days[i]
if isinstance(i, slice):
return [f(self.format) for f in funcs]
else:
return funcs(self.format)
def __len__(self):
return 7
# Full and abbreviated names of weekdays
day_name = _localized_day('%A')
day_abbr = _localized_day('%a')
# Full and abbreviated names of months (1-based arrays!!!)
month_name = _localized_month('%B')
month_abbr = _localized_month('%b')
# Constants for weekdays
(MONDAY, TUESDAY, WEDNESDAY, THURSDAY, FRIDAY, SATURDAY, SUNDAY) = range(7)
def isleap(year):
"""Return True for leap years, False for non-leap years."""
return year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)
def leapdays(y1, y2):
"""Return number of leap years in range [y1, y2).
Assume y1 <= y2."""
y1 -= 1
y2 -= 1
return (y2//4 - y1//4) - (y2//100 - y1//100) + (y2//400 - y1//400)
def weekday(year, month, day):
"""Return weekday (0-6 ~ Mon-Sun) for year (1970-...), month (1-12),
day (1-31)."""
return datetime.date(year, month, day).weekday()
def monthrange(year, month):
"""Return weekday (0-6 ~ Mon-Sun) and number of days (28-31) for
year, month."""
if not 1 <= month <= 12:
raise IllegalMonthError(month)
day1 = weekday(year, month, 1)
ndays = mdays[month] + (month == February and isleap(year))
return day1, ndays
class Calendar(object):
"""
Base calendar class. This class doesn't do any formatting. It simply
provides data to subclasses.
"""
def __init__(self, firstweekday=0):
self.firstweekday = firstweekday # 0 = Monday, 6 = Sunday
def getfirstweekday(self):
return self._firstweekday % 7
def setfirstweekday(self, firstweekday):
self._firstweekday = firstweekday
firstweekday = property(getfirstweekday, setfirstweekday)
def iterweekdays(self):
"""
Return a iterator for one week of weekday numbers starting with the
configured first one.
"""
for i in range(self.firstweekday, self.firstweekday + 7):
yield i%7
def itermonthdates(self, year, month):
"""
Return an iterator for one month. The iterator will yield datetime.date
values and will always iterate through complete weeks, so it will yield
dates outside the specified month.
"""
date = datetime.date(year, month, 1)
# Go back to the beginning of the week
days = (date.weekday() - self.firstweekday) % 7
date -= datetime.timedelta(days=days)
oneday = datetime.timedelta(days=1)
while True:
yield date
date += oneday
if date.month != month and date.weekday() == self.firstweekday:
break
def itermonthdays2(self, year, month):
"""
Like itermonthdates(), but will yield (day number, weekday number)
tuples. For days outside the specified month the day number is 0.
"""
for date in self.itermonthdates(year, month):
if date.month != month:
yield (0, date.weekday())
else:
yield (date.day, date.weekday())
def itermonthdays(self, year, month):
"""
Like itermonthdates(), but will yield day numbers. For days outside
the specified month the day number is 0.
"""
for date in self.itermonthdates(year, month):
if date.month != month:
yield 0
else:
yield date.day
def monthdatescalendar(self, year, month):
"""
Return a matrix (list of lists) representing a month's calendar.
Each row represents a week; week entries are datetime.date values.
"""
dates = list(self.itermonthdates(year, month))
return [ dates[i:i+7] for i in range(0, len(dates), 7) ]
def monthdays2calendar(self, year, month):
"""
Return a matrix representing a month's calendar.
Each row represents a week; week entries are
(day number, weekday number) tuples. Day numbers outside this month
are zero.
"""
days = list(self.itermonthdays2(year, month))
return [ days[i:i+7] for i in range(0, len(days), 7) ]
def monthdayscalendar(self, year, month):
"""
Return a matrix representing a month's calendar.
Each row represents a week; days outside this month are zero.
"""
days = list(self.itermonthdays(year, month))
return [ days[i:i+7] for i in range(0, len(days), 7) ]
def yeardatescalendar(self, year, width=3):
"""
Return the data for the specified year ready for formatting. The return
value is a list of month rows. Each month row contains upto width months.
Each month contains between 4 and 6 weeks and each week contains 1-7
days. Days are datetime.date objects.
"""
months = [
self.monthdatescalendar(year, i)
for i in range(January, January+12)
]
return [months[i:i+width] for i in range(0, len(months), width) ]
def yeardays2calendar(self, year, width=3):
"""
Return the data for the specified year ready for formatting (similar to
yeardatescalendar()). Entries in the week lists are
(day number, weekday number) tuples. Day numbers outside this month are
zero.
"""
months = [
self.monthdays2calendar(year, i)
for i in range(January, January+12)
]
return [months[i:i+width] for i in range(0, len(months), width) ]
def yeardayscalendar(self, year, width=3):
"""
Return the data for the specified year ready for formatting (similar to
yeardatescalendar()). Entries in the week lists are day numbers.
Day numbers outside this month are zero.
"""
months = [
self.monthdayscalendar(year, i)
for i in range(January, January+12)
]
return [months[i:i+width] for i in range(0, len(months), width) ]
class TextCalendar(Calendar):
"""
Subclass of Calendar that outputs a calendar as a simple plain text
similar to the UNIX program cal.
"""
def prweek(self, theweek, width):
"""
Print a single week (no newline).
"""
print self.formatweek(theweek, width),
def formatday(self, day, weekday, width):
"""
Returns a formatted day.
"""
if day == 0:
s = ''
else:
s = '%2i' % day # right-align single-digit days
return s.center(width)
def formatweek(self, theweek, width):
"""
Returns a single week in a string (no newline).
"""
return ' '.join(self.formatday(d, wd, width) for (d, wd) in theweek)
def formatweekday(self, day, width):
"""
Returns a formatted week day name.
"""
if width >= 9:
names = day_name
else:
names = day_abbr
return names[day][:width].center(width)
def formatweekheader(self, width):
"""
Return a header for a week.
"""
return ' '.join(self.formatweekday(i, width) for i in self.iterweekdays())
def formatmonthname(self, theyear, themonth, width, withyear=True):
"""
Return a formatted month name.
"""
s = month_name[themonth]
if withyear:
s = "%s %r" % (s, theyear)
return s.center(width)
def prmonth(self, theyear, themonth, w=0, l=0):
"""
Print a month's calendar.
"""
print self.formatmonth(theyear, themonth, w, l),
def formatmonth(self, theyear, themonth, w=0, l=0):
"""
Return a month's calendar string (multi-line).
"""
w = max(2, w)
l = max(1, l)
s = self.formatmonthname(theyear, themonth, 7 * (w + 1) - 1)
s = s.rstrip()
s += '\n' * l
s += self.formatweekheader(w).rstrip()
s += '\n' * l
for week in self.monthdays2calendar(theyear, themonth):
s += self.formatweek(week, w).rstrip()
s += '\n' * l
return s
def formatyear(self, theyear, w=2, l=1, c=6, m=3):
"""
Returns a year's calendar as a multi-line string.
"""
w = max(2, w)
l = max(1, l)
c = max(2, c)
colwidth = (w + 1) * 7 - 1
v = []
a = v.append
a(repr(theyear).center(colwidth*m+c*(m-1)).rstrip())
a('\n'*l)
header = self.formatweekheader(w)
for (i, row) in enumerate(self.yeardays2calendar(theyear, m)):
# months in this row
months = range(m*i+1, min(m*(i+1)+1, 13))
a('\n'*l)
names = (self.formatmonthname(theyear, k, colwidth, False)
for k in months)
a(formatstring(names, colwidth, c).rstrip())
a('\n'*l)
headers = (header for k in months)
a(formatstring(headers, colwidth, c).rstrip())
a('\n'*l)
# max number of weeks for this row
height = max(len(cal) for cal in row)
for j in range(height):
weeks = []
for cal in row:
if j >= len(cal):
weeks.append('')
else:
weeks.append(self.formatweek(cal[j], w))
a(formatstring(weeks, colwidth, c).rstrip())
a('\n' * l)
return ''.join(v)
def pryear(self, theyear, w=0, l=0, c=6, m=3):
"""Print a year's calendar."""
print self.formatyear(theyear, w, l, c, m)
class HTMLCalendar(Calendar):
"""
This calendar returns complete HTML pages.
"""
# CSS classes for the day <td>s
cssclasses = ["mon", "tue", "wed", "thu", "fri", "sat", "sun"]
def formatday(self, day, weekday):
"""
Return a day as a table cell.
"""
if day == 0:
return '<td class="noday">&nbsp;</td>' # day outside month
else:
return '<td class="%s">%d</td>' % (self.cssclasses[weekday], day)
def formatweek(self, theweek):
"""
Return a complete week as a table row.
"""
s = ''.join(self.formatday(d, wd) for (d, wd) in theweek)
return '<tr>%s</tr>' % s
def formatweekday(self, day):
"""
Return a weekday name as a table header.
"""
return '<th class="%s">%s</th>' % (self.cssclasses[day], day_abbr[day])
def formatweekheader(self):
"""
Return a header for a week as a table row.
"""
s = ''.join(self.formatweekday(i) for i in self.iterweekdays())
return '<tr>%s</tr>' % s
def formatmonthname(self, theyear, themonth, withyear=True):
"""
Return a month name as a table row.
"""
if withyear:
s = '%s %s' % (month_name[themonth], theyear)
else:
s = '%s' % month_name[themonth]
return '<tr><th colspan="7" class="month">%s</th></tr>' % s
def formatmonth(self, theyear, themonth, withyear=True):
"""
Return a formatted month as a table.
"""
v = []
a = v.append
a('<table border="0" cellpadding="0" cellspacing="0" class="month">')
a('\n')
a(self.formatmonthname(theyear, themonth, withyear=withyear))
a('\n')
a(self.formatweekheader())
a('\n')
for week in self.monthdays2calendar(theyear, themonth):
a(self.formatweek(week))
a('\n')
a('</table>')
a('\n')
return ''.join(v)
def formatyear(self, theyear, width=3):
"""
Return a formatted year as a table of tables.
"""
v = []
a = v.append
width = max(width, 1)
a('<table border="0" cellpadding="0" cellspacing="0" class="year">')
a('\n')
a('<tr><th colspan="%d" class="year">%s</th></tr>' % (width, theyear))
for i in range(January, January+12, width):
# months in this row
months = range(i, min(i+width, 13))
a('<tr>')
for m in months:
a('<td>')
a(self.formatmonth(theyear, m, withyear=False))
a('</td>')
a('</tr>')
a('</table>')
return ''.join(v)
def formatyearpage(self, theyear, width=3, css='calendar.css', encoding=None):
"""
Return a formatted year as a complete HTML page.
"""
if encoding is None:
encoding = sys.getdefaultencoding()
v = []
a = v.append
a('<?xml version="1.0" encoding="%s"?>\n' % encoding)
a('<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Strict//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd">\n')
a('<html>\n')
a('<head>\n')
a('<meta http-equiv="Content-Type" content="text/html; charset=%s" />\n' % encoding)
if css is not None:
a('<link rel="stylesheet" type="text/css" href="%s" />\n' % css)
a('<title>Calendar for %d</title>\n' % theyear)
a('</head>\n')
a('<body>\n')
a(self.formatyear(theyear, width))
a('</body>\n')
a('</html>\n')
return ''.join(v).encode(encoding, "xmlcharrefreplace")
class TimeEncoding:
def __init__(self, locale):
self.locale = locale
def __enter__(self):
self.oldlocale = _locale.getlocale(_locale.LC_TIME)
_locale.setlocale(_locale.LC_TIME, self.locale)
def __exit__(self, *args):
_locale.setlocale(_locale.LC_TIME, self.oldlocale)
class LocaleTextCalendar(TextCalendar):
"""
This class can be passed a locale name in the constructor and will return
month and weekday names in the specified locale. If this locale includes
an encoding all strings containing month and weekday names will be returned
as unicode.
"""
def __init__(self, firstweekday=0, locale=None):
TextCalendar.__init__(self, firstweekday)
if locale is None:
locale = _locale.getdefaultlocale()
self.locale = locale
def formatweekday(self, day, width):
with TimeEncoding(self.locale) as encoding:
if width >= 9:
names = day_name
else:
names = day_abbr
name = names[day]
if encoding is not None:
name = name.decode(encoding)
return name[:width].center(width)
def formatmonthname(self, theyear, themonth, width, withyear=True):
with TimeEncoding(self.locale) as encoding:
s = month_name[themonth]
if encoding is not None:
s = s.decode(encoding)
if withyear:
s = "%s %r" % (s, theyear)
return s.center(width)
class LocaleHTMLCalendar(HTMLCalendar):
"""
This class can be passed a locale name in the constructor and will return
month and weekday names in the specified locale. If this locale includes
an encoding all strings containing month and weekday names will be returned
as unicode.
"""
def __init__(self, firstweekday=0, locale=None):
HTMLCalendar.__init__(self, firstweekday)
if locale is None:
locale = _locale.getdefaultlocale()
self.locale = locale
def formatweekday(self, day):
with TimeEncoding(self.locale) as encoding:
s = day_abbr[day]
if encoding is not None:
s = s.decode(encoding)
return '<th class="%s">%s</th>' % (self.cssclasses[day], s)
def formatmonthname(self, theyear, themonth, withyear=True):
with TimeEncoding(self.locale) as encoding:
s = month_name[themonth]
if encoding is not None:
s = s.decode(encoding)
if withyear:
s = '%s %s' % (s, theyear)
return '<tr><th colspan="7" class="month">%s</th></tr>' % s
# Support for old module level interface
c = TextCalendar()
firstweekday = c.getfirstweekday
def setfirstweekday(firstweekday):
try:
firstweekday.__index__
except AttributeError:
raise IllegalWeekdayError(firstweekday)
if not MONDAY <= firstweekday <= SUNDAY:
raise IllegalWeekdayError(firstweekday)
c.firstweekday = firstweekday
monthcalendar = c.monthdayscalendar
prweek = c.prweek
week = c.formatweek
weekheader = c.formatweekheader
prmonth = c.prmonth
month = c.formatmonth
calendar = c.formatyear
prcal = c.pryear
# Spacing of month columns for multi-column year calendar
_colwidth = 7*3 - 1 # Amount printed by prweek()
_spacing = 6 # Number of spaces between columns
def format(cols, colwidth=_colwidth, spacing=_spacing):
"""Prints multi-column formatting for year calendars"""
print formatstring(cols, colwidth, spacing)
def formatstring(cols, colwidth=_colwidth, spacing=_spacing):
"""Returns a string formatted from n strings, centered within n columns."""
spacing *= ' '
return spacing.join(c.center(colwidth) for c in cols)
EPOCH = 1970
_EPOCH_ORD = datetime.date(EPOCH, 1, 1).toordinal()
def timegm(tuple):
"""Unrelated but handy function to calculate Unix timestamp from GMT."""
year, month, day, hour, minute, second = tuple[:6]
days = datetime.date(year, month, 1).toordinal() - _EPOCH_ORD + day - 1
hours = days*24 + hour
minutes = hours*60 + minute
seconds = minutes*60 + second
return seconds
def main(args):
import optparse
parser = optparse.OptionParser(usage="usage: %prog [options] [year [month]]")
parser.add_option(
"-w", "--width",
dest="width", type="int", default=2,
help="width of date column (default 2, text only)"
)
parser.add_option(
"-l", "--lines",
dest="lines", type="int", default=1,
help="number of lines for each week (default 1, text only)"
)
parser.add_option(
"-s", "--spacing",
dest="spacing", type="int", default=6,
help="spacing between months (default 6, text only)"
)
parser.add_option(
"-m", "--months",
dest="months", type="int", default=3,
help="months per row (default 3, text only)"
)
parser.add_option(
"-c", "--css",
dest="css", default="calendar.css",
help="CSS to use for page (html only)"
)
parser.add_option(
"-L", "--locale",
dest="locale", default=None,
help="locale to be used from month and weekday names"
)
parser.add_option(
"-e", "--encoding",
dest="encoding", default=None,
help="Encoding to use for output"
)
parser.add_option(
"-t", "--type",
dest="type", default="text",
choices=("text", "html"),
help="output type (text or html)"
)
(options, args) = parser.parse_args(args)
if options.locale and not options.encoding:
parser.error("if --locale is specified --encoding is required")
sys.exit(1)
locale = options.locale, options.encoding
if options.type == "html":
if options.locale:
cal = LocaleHTMLCalendar(locale=locale)
else:
cal = HTMLCalendar()
encoding = options.encoding
if encoding is None:
encoding = sys.getdefaultencoding()
optdict = dict(encoding=encoding, css=options.css)
if len(args) == 1:
print cal.formatyearpage(datetime.date.today().year, **optdict)
elif len(args) == 2:
print cal.formatyearpage(int(args[1]), **optdict)
else:
parser.error("incorrect number of arguments")
sys.exit(1)
else:
if options.locale:
cal = LocaleTextCalendar(locale=locale)
else:
cal = TextCalendar()
optdict = dict(w=options.width, l=options.lines)
if len(args) != 3:
optdict["c"] = options.spacing
optdict["m"] = options.months
if len(args) == 1:
result = cal.formatyear(datetime.date.today().year, **optdict)
elif len(args) == 2:
result = cal.formatyear(int(args[1]), **optdict)
elif len(args) == 3:
result = cal.formatmonth(int(args[1]), int(args[2]), **optdict)
else:
parser.error("incorrect number of arguments")
sys.exit(1)
if options.encoding:
result = result.encode(options.encoding)
print result
if __name__ == "__main__":
main(sys.argv)

1051
src/python/stdlib/cgi.py Executable file

File diff suppressed because it is too large Load Diff

318
src/python/stdlib/cgitb.py Normal file
View File

@@ -0,0 +1,318 @@
"""More comprehensive traceback formatting for Python scripts.
To enable this module, do:
import cgitb; cgitb.enable()
at the top of your script. The optional arguments to enable() are:
display - if true, tracebacks are displayed in the web browser
logdir - if set, tracebacks are written to files in this directory
context - number of lines of source code to show for each stack frame
format - 'text' or 'html' controls the output format
By default, tracebacks are displayed but not saved, the context is 5 lines
and the output format is 'html' (for backwards compatibility with the
original use of this module)
Alternatively, if you have caught an exception and want cgitb to display it
for you, call cgitb.handler(). The optional argument to handler() is a
3-item tuple (etype, evalue, etb) just like the value of sys.exc_info().
The default handler displays output as HTML.
"""
import inspect
import keyword
import linecache
import os
import pydoc
import sys
import tempfile
import time
import tokenize
import traceback
import types
def reset():
"""Return a string that resets the CGI and browser to a known state."""
return '''<!--: spam
Content-Type: text/html
<body bgcolor="#f0f0f8"><font color="#f0f0f8" size="-5"> -->
<body bgcolor="#f0f0f8"><font color="#f0f0f8" size="-5"> --> -->
</font> </font> </font> </script> </object> </blockquote> </pre>
</table> </table> </table> </table> </table> </font> </font> </font>'''
__UNDEF__ = [] # a special sentinel object
def small(text):
if text:
return '<small>' + text + '</small>'
else:
return ''
def strong(text):
if text:
return '<strong>' + text + '</strong>'
else:
return ''
def grey(text):
if text:
return '<font color="#909090">' + text + '</font>'
else:
return ''
def lookup(name, frame, locals):
"""Find the value for a given name in the given environment."""
if name in locals:
return 'local', locals[name]
if name in frame.f_globals:
return 'global', frame.f_globals[name]
if '__builtins__' in frame.f_globals:
builtins = frame.f_globals['__builtins__']
if type(builtins) is type({}):
if name in builtins:
return 'builtin', builtins[name]
else:
if hasattr(builtins, name):
return 'builtin', getattr(builtins, name)
return None, __UNDEF__
def scanvars(reader, frame, locals):
"""Scan one logical line of Python and look up values of variables used."""
vars, lasttoken, parent, prefix, value = [], None, None, '', __UNDEF__
for ttype, token, start, end, line in tokenize.generate_tokens(reader):
if ttype == tokenize.NEWLINE: break
if ttype == tokenize.NAME and token not in keyword.kwlist:
if lasttoken == '.':
if parent is not __UNDEF__:
value = getattr(parent, token, __UNDEF__)
vars.append((prefix + token, prefix, value))
else:
where, value = lookup(token, frame, locals)
vars.append((token, where, value))
elif token == '.':
prefix += lasttoken + '.'
parent = value
else:
parent, prefix = None, ''
lasttoken = token
return vars
def html(einfo, context=5):
"""Return a nice HTML document describing a given traceback."""
etype, evalue, etb = einfo
if type(etype) is types.ClassType:
etype = etype.__name__
pyver = 'Python ' + sys.version.split()[0] + ': ' + sys.executable
date = time.ctime(time.time())
head = '<body bgcolor="#f0f0f8">' + pydoc.html.heading(
'<big><big>%s</big></big>' %
strong(pydoc.html.escape(str(etype))),
'#ffffff', '#6622aa', pyver + '<br>' + date) + '''
<p>A problem occurred in a Python script. Here is the sequence of
function calls leading up to the error, in the order they occurred.</p>'''
indent = '<tt>' + small('&nbsp;' * 5) + '&nbsp;</tt>'
frames = []
records = inspect.getinnerframes(etb, context)
for frame, file, lnum, func, lines, index in records:
if file:
file = os.path.abspath(file)
link = '<a href="file://%s">%s</a>' % (file, pydoc.html.escape(file))
else:
file = link = '?'
args, varargs, varkw, locals = inspect.getargvalues(frame)
call = ''
if func != '?':
call = 'in ' + strong(func) + \
inspect.formatargvalues(args, varargs, varkw, locals,
formatvalue=lambda value: '=' + pydoc.html.repr(value))
highlight = {}
def reader(lnum=[lnum]):
highlight[lnum[0]] = 1
try: return linecache.getline(file, lnum[0])
finally: lnum[0] += 1
vars = scanvars(reader, frame, locals)
rows = ['<tr><td bgcolor="#d8bbff">%s%s %s</td></tr>' %
('<big>&nbsp;</big>', link, call)]
if index is not None:
i = lnum - index
for line in lines:
num = small('&nbsp;' * (5-len(str(i))) + str(i)) + '&nbsp;'
if i in highlight:
line = '<tt>=&gt;%s%s</tt>' % (num, pydoc.html.preformat(line))
rows.append('<tr><td bgcolor="#ffccee">%s</td></tr>' % line)
else:
line = '<tt>&nbsp;&nbsp;%s%s</tt>' % (num, pydoc.html.preformat(line))
rows.append('<tr><td>%s</td></tr>' % grey(line))
i += 1
done, dump = {}, []
for name, where, value in vars:
if name in done: continue
done[name] = 1
if value is not __UNDEF__:
if where in ('global', 'builtin'):
name = ('<em>%s</em> ' % where) + strong(name)
elif where == 'local':
name = strong(name)
else:
name = where + strong(name.split('.')[-1])
dump.append('%s&nbsp;= %s' % (name, pydoc.html.repr(value)))
else:
dump.append(name + ' <em>undefined</em>')
rows.append('<tr><td>%s</td></tr>' % small(grey(', '.join(dump))))
frames.append('''
<table width="100%%" cellspacing=0 cellpadding=0 border=0>
%s</table>''' % '\n'.join(rows))
exception = ['<p>%s: %s' % (strong(pydoc.html.escape(str(etype))),
pydoc.html.escape(str(evalue)))]
if isinstance(evalue, BaseException):
for name in dir(evalue):
if name[:1] == '_': continue
value = pydoc.html.repr(getattr(evalue, name))
exception.append('\n<br>%s%s&nbsp;=\n%s' % (indent, name, value))
return head + ''.join(frames) + ''.join(exception) + '''
<!-- The above is a description of an error in a Python program, formatted
for a Web browser because the 'cgitb' module was enabled. In case you
are not reading this in a Web browser, here is the original traceback:
%s
-->
''' % pydoc.html.escape(
''.join(traceback.format_exception(etype, evalue, etb)))
def text(einfo, context=5):
"""Return a plain text document describing a given traceback."""
etype, evalue, etb = einfo
if type(etype) is types.ClassType:
etype = etype.__name__
pyver = 'Python ' + sys.version.split()[0] + ': ' + sys.executable
date = time.ctime(time.time())
head = "%s\n%s\n%s\n" % (str(etype), pyver, date) + '''
A problem occurred in a Python script. Here is the sequence of
function calls leading up to the error, in the order they occurred.
'''
frames = []
records = inspect.getinnerframes(etb, context)
for frame, file, lnum, func, lines, index in records:
file = file and os.path.abspath(file) or '?'
args, varargs, varkw, locals = inspect.getargvalues(frame)
call = ''
if func != '?':
call = 'in ' + func + \
inspect.formatargvalues(args, varargs, varkw, locals,
formatvalue=lambda value: '=' + pydoc.text.repr(value))
highlight = {}
def reader(lnum=[lnum]):
highlight[lnum[0]] = 1
try: return linecache.getline(file, lnum[0])
finally: lnum[0] += 1
vars = scanvars(reader, frame, locals)
rows = [' %s %s' % (file, call)]
if index is not None:
i = lnum - index
for line in lines:
num = '%5d ' % i
rows.append(num+line.rstrip())
i += 1
done, dump = {}, []
for name, where, value in vars:
if name in done: continue
done[name] = 1
if value is not __UNDEF__:
if where == 'global': name = 'global ' + name
elif where != 'local': name = where + name.split('.')[-1]
dump.append('%s = %s' % (name, pydoc.text.repr(value)))
else:
dump.append(name + ' undefined')
rows.append('\n'.join(dump))
frames.append('\n%s\n' % '\n'.join(rows))
exception = ['%s: %s' % (str(etype), str(evalue))]
if isinstance(evalue, BaseException):
for name in dir(evalue):
value = pydoc.text.repr(getattr(evalue, name))
exception.append('\n%s%s = %s' % (" "*4, name, value))
return head + ''.join(frames) + ''.join(exception) + '''
The above is a description of an error in a Python program. Here is
the original traceback:
%s
''' % ''.join(traceback.format_exception(etype, evalue, etb))
class Hook:
"""A hook to replace sys.excepthook that shows tracebacks in HTML."""
def __init__(self, display=1, logdir=None, context=5, file=None,
format="html"):
self.display = display # send tracebacks to browser if true
self.logdir = logdir # log tracebacks to files if not None
self.context = context # number of source code lines per frame
self.file = file or sys.stdout # place to send the output
self.format = format
def __call__(self, etype, evalue, etb):
self.handle((etype, evalue, etb))
def handle(self, info=None):
info = info or sys.exc_info()
if self.format == "html":
self.file.write(reset())
formatter = (self.format=="html") and html or text
plain = False
try:
doc = formatter(info, self.context)
except: # just in case something goes wrong
doc = ''.join(traceback.format_exception(*info))
plain = True
if self.display:
if plain:
doc = doc.replace('&', '&amp;').replace('<', '&lt;')
self.file.write('<pre>' + doc + '</pre>\n')
else:
self.file.write(doc + '\n')
else:
self.file.write('<p>A problem occurred in a Python script.\n')
if self.logdir is not None:
suffix = ['.txt', '.html'][self.format=="html"]
(fd, path) = tempfile.mkstemp(suffix=suffix, dir=self.logdir)
try:
file = os.fdopen(fd, 'w')
file.write(doc)
file.close()
msg = '<p> %s contains the description of this error.' % path
except:
msg = '<p> Tried to save traceback to %s, but failed.' % path
self.file.write(msg + '\n')
try:
self.file.flush()
except: pass
handler = Hook().handle
def enable(display=1, logdir=None, context=5, format="html"):
"""Install an exception handler that formats tracebacks as HTML.
The optional argument 'display' can be set to 0 to suppress sending the
traceback to the browser, and 'logdir' can be set to a directory to cause
tracebacks to be written to files there."""
sys.excepthook = Hook(display=display, logdir=logdir,
context=context, format=format)

167
src/python/stdlib/chunk.py Normal file
View File

@@ -0,0 +1,167 @@
"""Simple class to read IFF chunks.
An IFF chunk (used in formats such as AIFF, TIFF, RMFF (RealMedia File
Format)) has the following structure:
+----------------+
| ID (4 bytes) |
+----------------+
| size (4 bytes) |
+----------------+
| data |
| ... |
+----------------+
The ID is a 4-byte string which identifies the type of chunk.
The size field (a 32-bit value, encoded using big-endian byte order)
gives the size of the whole chunk, including the 8-byte header.
Usually an IFF-type file consists of one or more chunks. The proposed
usage of the Chunk class defined here is to instantiate an instance at
the start of each chunk and read from the instance until it reaches
the end, after which a new instance can be instantiated. At the end
of the file, creating a new instance will fail with a EOFError
exception.
Usage:
while True:
try:
chunk = Chunk(file)
except EOFError:
break
chunktype = chunk.getname()
while True:
data = chunk.read(nbytes)
if not data:
pass
# do something with data
The interface is file-like. The implemented methods are:
read, close, seek, tell, isatty.
Extra methods are: skip() (called by close, skips to the end of the chunk),
getname() (returns the name (ID) of the chunk)
The __init__ method has one required argument, a file-like object
(including a chunk instance), and one optional argument, a flag which
specifies whether or not chunks are aligned on 2-byte boundaries. The
default is 1, i.e. aligned.
"""
class Chunk:
def __init__(self, file, align=True, bigendian=True, inclheader=False):
import struct
self.closed = False
self.align = align # whether to align to word (2-byte) boundaries
if bigendian:
strflag = '>'
else:
strflag = '<'
self.file = file
self.chunkname = file.read(4)
if len(self.chunkname) < 4:
raise EOFError
try:
self.chunksize = struct.unpack(strflag+'L', file.read(4))[0]
except struct.error:
raise EOFError
if inclheader:
self.chunksize = self.chunksize - 8 # subtract header
self.size_read = 0
try:
self.offset = self.file.tell()
except (AttributeError, IOError):
self.seekable = False
else:
self.seekable = True
def getname(self):
"""Return the name (ID) of the current chunk."""
return self.chunkname
def getsize(self):
"""Return the size of the current chunk."""
return self.chunksize
def close(self):
if not self.closed:
self.skip()
self.closed = True
def isatty(self):
if self.closed:
raise ValueError, "I/O operation on closed file"
return False
def seek(self, pos, whence=0):
"""Seek to specified position into the chunk.
Default position is 0 (start of chunk).
If the file is not seekable, this will result in an error.
"""
if self.closed:
raise ValueError, "I/O operation on closed file"
if not self.seekable:
raise IOError, "cannot seek"
if whence == 1:
pos = pos + self.size_read
elif whence == 2:
pos = pos + self.chunksize
if pos < 0 or pos > self.chunksize:
raise RuntimeError
self.file.seek(self.offset + pos, 0)
self.size_read = pos
def tell(self):
if self.closed:
raise ValueError, "I/O operation on closed file"
return self.size_read
def read(self, size=-1):
"""Read at most size bytes from the chunk.
If size is omitted or negative, read until the end
of the chunk.
"""
if self.closed:
raise ValueError, "I/O operation on closed file"
if self.size_read >= self.chunksize:
return ''
if size < 0:
size = self.chunksize - self.size_read
if size > self.chunksize - self.size_read:
size = self.chunksize - self.size_read
data = self.file.read(size)
self.size_read = self.size_read + len(data)
if self.size_read == self.chunksize and \
self.align and \
(self.chunksize & 1):
dummy = self.file.read(1)
self.size_read = self.size_read + len(dummy)
return data
def skip(self):
"""Skip the rest of the chunk.
If you are not interested in the contents of the chunk,
this method should be called so that the file points to
the start of the next chunk.
"""
if self.closed:
raise ValueError, "I/O operation on closed file"
if self.seekable:
try:
n = self.chunksize - self.size_read
# maybe fix alignment
if self.align and (self.chunksize & 1):
n = n + 1
self.file.seek(n, 1)
self.size_read = self.size_read + n
return
except IOError:
pass
while self.size_read < self.chunksize:
n = min(8192, self.chunksize - self.size_read)
dummy = self.read(n)
if not dummy:
raise EOFError

401
src/python/stdlib/cmd.py Normal file
View File

@@ -0,0 +1,401 @@
"""A generic class to build line-oriented command interpreters.
Interpreters constructed with this class obey the following conventions:
1. End of file on input is processed as the command 'EOF'.
2. A command is parsed out of each line by collecting the prefix composed
of characters in the identchars member.
3. A command `foo' is dispatched to a method 'do_foo()'; the do_ method
is passed a single argument consisting of the remainder of the line.
4. Typing an empty line repeats the last command. (Actually, it calls the
method `emptyline', which may be overridden in a subclass.)
5. There is a predefined `help' method. Given an argument `topic', it
calls the command `help_topic'. With no arguments, it lists all topics
with defined help_ functions, broken into up to three topics; documented
commands, miscellaneous help topics, and undocumented commands.
6. The command '?' is a synonym for `help'. The command '!' is a synonym
for `shell', if a do_shell method exists.
7. If completion is enabled, completing commands will be done automatically,
and completing of commands args is done by calling complete_foo() with
arguments text, line, begidx, endidx. text is string we are matching
against, all returned matches must begin with it. line is the current
input line (lstripped), begidx and endidx are the beginning and end
indexes of the text being matched, which could be used to provide
different completion depending upon which position the argument is in.
The `default' method may be overridden to intercept commands for which there
is no do_ method.
The `completedefault' method may be overridden to intercept completions for
commands that have no complete_ method.
The data member `self.ruler' sets the character used to draw separator lines
in the help messages. If empty, no ruler line is drawn. It defaults to "=".
If the value of `self.intro' is nonempty when the cmdloop method is called,
it is printed out on interpreter startup. This value may be overridden
via an optional argument to the cmdloop() method.
The data members `self.doc_header', `self.misc_header', and
`self.undoc_header' set the headers used for the help function's
listings of documented functions, miscellaneous topics, and undocumented
functions respectively.
These interpreters use raw_input; thus, if the readline module is loaded,
they automatically support Emacs-like command history and editing features.
"""
import string
__all__ = ["Cmd"]
PROMPT = '(Cmd) '
IDENTCHARS = string.ascii_letters + string.digits + '_'
class Cmd:
"""A simple framework for writing line-oriented command interpreters.
These are often useful for test harnesses, administrative tools, and
prototypes that will later be wrapped in a more sophisticated interface.
A Cmd instance or subclass instance is a line-oriented interpreter
framework. There is no good reason to instantiate Cmd itself; rather,
it's useful as a superclass of an interpreter class you define yourself
in order to inherit Cmd's methods and encapsulate action methods.
"""
prompt = PROMPT
identchars = IDENTCHARS
ruler = '='
lastcmd = ''
intro = None
doc_leader = ""
doc_header = "Documented commands (type help <topic>):"
misc_header = "Miscellaneous help topics:"
undoc_header = "Undocumented commands:"
nohelp = "*** No help on %s"
use_rawinput = 1
def __init__(self, completekey='tab', stdin=None, stdout=None):
"""Instantiate a line-oriented interpreter framework.
The optional argument 'completekey' is the readline name of a
completion key; it defaults to the Tab key. If completekey is
not None and the readline module is available, command completion
is done automatically. The optional arguments stdin and stdout
specify alternate input and output file objects; if not specified,
sys.stdin and sys.stdout are used.
"""
import sys
if stdin is not None:
self.stdin = stdin
else:
self.stdin = sys.stdin
if stdout is not None:
self.stdout = stdout
else:
self.stdout = sys.stdout
self.cmdqueue = []
self.completekey = completekey
def cmdloop(self, intro=None):
"""Repeatedly issue a prompt, accept input, parse an initial prefix
off the received input, and dispatch to action methods, passing them
the remainder of the line as argument.
"""
self.preloop()
if self.use_rawinput and self.completekey:
try:
import readline
self.old_completer = readline.get_completer()
readline.set_completer(self.complete)
readline.parse_and_bind(self.completekey+": complete")
except ImportError:
pass
try:
if intro is not None:
self.intro = intro
if self.intro:
self.stdout.write(str(self.intro)+"\n")
stop = None
while not stop:
if self.cmdqueue:
line = self.cmdqueue.pop(0)
else:
if self.use_rawinput:
try:
line = raw_input(self.prompt)
except EOFError:
line = 'EOF'
else:
self.stdout.write(self.prompt)
self.stdout.flush()
line = self.stdin.readline()
if not len(line):
line = 'EOF'
else:
line = line.rstrip('\r\n')
line = self.precmd(line)
stop = self.onecmd(line)
stop = self.postcmd(stop, line)
self.postloop()
finally:
if self.use_rawinput and self.completekey:
try:
import readline
readline.set_completer(self.old_completer)
except ImportError:
pass
def precmd(self, line):
"""Hook method executed just before the command line is
interpreted, but after the input prompt is generated and issued.
"""
return line
def postcmd(self, stop, line):
"""Hook method executed just after a command dispatch is finished."""
return stop
def preloop(self):
"""Hook method executed once when the cmdloop() method is called."""
pass
def postloop(self):
"""Hook method executed once when the cmdloop() method is about to
return.
"""
pass
def parseline(self, line):
"""Parse the line into a command name and a string containing
the arguments. Returns a tuple containing (command, args, line).
'command' and 'args' may be None if the line couldn't be parsed.
"""
line = line.strip()
if not line:
return None, None, line
elif line[0] == '?':
line = 'help ' + line[1:]
elif line[0] == '!':
if hasattr(self, 'do_shell'):
line = 'shell ' + line[1:]
else:
return None, None, line
i, n = 0, len(line)
while i < n and line[i] in self.identchars: i = i+1
cmd, arg = line[:i], line[i:].strip()
return cmd, arg, line
def onecmd(self, line):
"""Interpret the argument as though it had been typed in response
to the prompt.
This may be overridden, but should not normally need to be;
see the precmd() and postcmd() methods for useful execution hooks.
The return value is a flag indicating whether interpretation of
commands by the interpreter should stop.
"""
cmd, arg, line = self.parseline(line)
if not line:
return self.emptyline()
if cmd is None:
return self.default(line)
self.lastcmd = line
if cmd == '':
return self.default(line)
else:
try:
func = getattr(self, 'do_' + cmd)
except AttributeError:
return self.default(line)
return func(arg)
def emptyline(self):
"""Called when an empty line is entered in response to the prompt.
If this method is not overridden, it repeats the last nonempty
command entered.
"""
if self.lastcmd:
return self.onecmd(self.lastcmd)
def default(self, line):
"""Called on an input line when the command prefix is not recognized.
If this method is not overridden, it prints an error message and
returns.
"""
self.stdout.write('*** Unknown syntax: %s\n'%line)
def completedefault(self, *ignored):
"""Method called to complete an input line when no command-specific
complete_*() method is available.
By default, it returns an empty list.
"""
return []
def completenames(self, text, *ignored):
dotext = 'do_'+text
return [a[3:] for a in self.get_names() if a.startswith(dotext)]
def complete(self, text, state):
"""Return the next possible completion for 'text'.
If a command has not been entered, then complete against command list.
Otherwise try to call complete_<command> to get list of completions.
"""
if state == 0:
import readline
origline = readline.get_line_buffer()
line = origline.lstrip()
stripped = len(origline) - len(line)
begidx = readline.get_begidx() - stripped
endidx = readline.get_endidx() - stripped
if begidx>0:
cmd, args, foo = self.parseline(line)
if cmd == '':
compfunc = self.completedefault
else:
try:
compfunc = getattr(self, 'complete_' + cmd)
except AttributeError:
compfunc = self.completedefault
else:
compfunc = self.completenames
self.completion_matches = compfunc(text, line, begidx, endidx)
try:
return self.completion_matches[state]
except IndexError:
return None
def get_names(self):
# This method used to pull in base class attributes
# at a time dir() didn't do it yet.
return dir(self.__class__)
def complete_help(self, *args):
commands = set(self.completenames(*args))
topics = set(a[5:] for a in self.get_names()
if a.startswith('help_' + args[0]))
return list(commands | topics)
def do_help(self, arg):
if arg:
# XXX check arg syntax
try:
func = getattr(self, 'help_' + arg)
except AttributeError:
try:
doc=getattr(self, 'do_' + arg).__doc__
if doc:
self.stdout.write("%s\n"%str(doc))
return
except AttributeError:
pass
self.stdout.write("%s\n"%str(self.nohelp % (arg,)))
return
func()
else:
names = self.get_names()
cmds_doc = []
cmds_undoc = []
help = {}
for name in names:
if name[:5] == 'help_':
help[name[5:]]=1
names.sort()
# There can be duplicates if routines overridden
prevname = ''
for name in names:
if name[:3] == 'do_':
if name == prevname:
continue
prevname = name
cmd=name[3:]
if cmd in help:
cmds_doc.append(cmd)
del help[cmd]
elif getattr(self, name).__doc__:
cmds_doc.append(cmd)
else:
cmds_undoc.append(cmd)
self.stdout.write("%s\n"%str(self.doc_leader))
self.print_topics(self.doc_header, cmds_doc, 15,80)
self.print_topics(self.misc_header, help.keys(),15,80)
self.print_topics(self.undoc_header, cmds_undoc, 15,80)
def print_topics(self, header, cmds, cmdlen, maxcol):
if cmds:
self.stdout.write("%s\n"%str(header))
if self.ruler:
self.stdout.write("%s\n"%str(self.ruler * len(header)))
self.columnize(cmds, maxcol-1)
self.stdout.write("\n")
def columnize(self, list, displaywidth=80):
"""Display a list of strings as a compact set of columns.
Each column is only as wide as necessary.
Columns are separated by two spaces (one was not legible enough).
"""
if not list:
self.stdout.write("<empty>\n")
return
nonstrings = [i for i in range(len(list))
if not isinstance(list[i], str)]
if nonstrings:
raise TypeError, ("list[i] not a string for i in %s" %
", ".join(map(str, nonstrings)))
size = len(list)
if size == 1:
self.stdout.write('%s\n'%str(list[0]))
return
# Try every row count from 1 upwards
for nrows in range(1, len(list)):
ncols = (size+nrows-1) // nrows
colwidths = []
totwidth = -2
for col in range(ncols):
colwidth = 0
for row in range(nrows):
i = row + nrows*col
if i >= size:
break
x = list[i]
colwidth = max(colwidth, len(x))
colwidths.append(colwidth)
totwidth += colwidth + 2
if totwidth > displaywidth:
break
if totwidth <= displaywidth:
break
else:
nrows = len(list)
ncols = 1
colwidths = [0]
for row in range(nrows):
texts = []
for col in range(ncols):
i = row + nrows*col
if i >= size:
x = ""
else:
x = list[i]
texts.append(x)
while texts and not texts[-1]:
del texts[-1]
for col in range(len(texts)):
texts[col] = texts[col].ljust(colwidths[col])
self.stdout.write("%s\n"%str(" ".join(texts)))

310
src/python/stdlib/code.py Normal file
View File

@@ -0,0 +1,310 @@
"""Utilities needed to emulate Python's interactive interpreter.
"""
# Inspired by similar code by Jeff Epler and Fredrik Lundh.
import sys
import traceback
from codeop import CommandCompiler, compile_command
__all__ = ["InteractiveInterpreter", "InteractiveConsole", "interact",
"compile_command"]
def softspace(file, newvalue):
oldvalue = 0
try:
oldvalue = file.softspace
except AttributeError:
pass
try:
file.softspace = newvalue
except (AttributeError, TypeError):
# "attribute-less object" or "read-only attributes"
pass
return oldvalue
class InteractiveInterpreter:
"""Base class for InteractiveConsole.
This class deals with parsing and interpreter state (the user's
namespace); it doesn't deal with input buffering or prompting or
input file naming (the filename is always passed in explicitly).
"""
def __init__(self, locals=None):
"""Constructor.
The optional 'locals' argument specifies the dictionary in
which code will be executed; it defaults to a newly created
dictionary with key "__name__" set to "__console__" and key
"__doc__" set to None.
"""
if locals is None:
locals = {"__name__": "__console__", "__doc__": None}
self.locals = locals
self.compile = CommandCompiler()
def runsource(self, source, filename="<input>", symbol="single"):
"""Compile and run some source in the interpreter.
Arguments are as for compile_command().
One several things can happen:
1) The input is incorrect; compile_command() raised an
exception (SyntaxError or OverflowError). A syntax traceback
will be printed by calling the showsyntaxerror() method.
2) The input is incomplete, and more input is required;
compile_command() returned None. Nothing happens.
3) The input is complete; compile_command() returned a code
object. The code is executed by calling self.runcode() (which
also handles run-time exceptions, except for SystemExit).
The return value is True in case 2, False in the other cases (unless
an exception is raised). The return value can be used to
decide whether to use sys.ps1 or sys.ps2 to prompt the next
line.
"""
try:
code = self.compile(source, filename, symbol)
except (OverflowError, SyntaxError, ValueError):
# Case 1
self.showsyntaxerror(filename)
return False
if code is None:
# Case 2
return True
# Case 3
self.runcode(code)
return False
def runcode(self, code):
"""Execute a code object.
When an exception occurs, self.showtraceback() is called to
display a traceback. All exceptions are caught except
SystemExit, which is reraised.
A note about KeyboardInterrupt: this exception may occur
elsewhere in this code, and may not always be caught. The
caller should be prepared to deal with it.
"""
try:
exec code in self.locals
except SystemExit:
raise
except:
self.showtraceback()
else:
if softspace(sys.stdout, 0):
print
def showsyntaxerror(self, filename=None):
"""Display the syntax error that just occurred.
This doesn't display a stack trace because there isn't one.
If a filename is given, it is stuffed in the exception instead
of what was there before (because Python's parser always uses
"<string>" when reading from a string).
The output is written by self.write(), below.
"""
type, value, sys.last_traceback = sys.exc_info()
sys.last_type = type
sys.last_value = value
if filename and type is SyntaxError:
# Work hard to stuff the correct filename in the exception
try:
msg, (dummy_filename, lineno, offset, line) = value
except:
# Not the format we expect; leave it alone
pass
else:
# Stuff in the right filename
value = SyntaxError(msg, (filename, lineno, offset, line))
sys.last_value = value
list = traceback.format_exception_only(type, value)
map(self.write, list)
def showtraceback(self):
"""Display the exception that just occurred.
We remove the first stack item because it is our own code.
The output is written by self.write(), below.
"""
try:
type, value, tb = sys.exc_info()
sys.last_type = type
sys.last_value = value
sys.last_traceback = tb
tblist = traceback.extract_tb(tb)
del tblist[:1]
list = traceback.format_list(tblist)
if list:
list.insert(0, "Traceback (most recent call last):\n")
list[len(list):] = traceback.format_exception_only(type, value)
finally:
tblist = tb = None
map(self.write, list)
def write(self, data):
"""Write a string.
The base implementation writes to sys.stderr; a subclass may
replace this with a different implementation.
"""
sys.stderr.write(data)
class InteractiveConsole(InteractiveInterpreter):
"""Closely emulate the behavior of the interactive Python interpreter.
This class builds on InteractiveInterpreter and adds prompting
using the familiar sys.ps1 and sys.ps2, and input buffering.
"""
def __init__(self, locals=None, filename="<console>"):
"""Constructor.
The optional locals argument will be passed to the
InteractiveInterpreter base class.
The optional filename argument should specify the (file)name
of the input stream; it will show up in tracebacks.
"""
InteractiveInterpreter.__init__(self, locals)
self.filename = filename
self.resetbuffer()
def resetbuffer(self):
"""Reset the input buffer."""
self.buffer = []
def interact(self, banner=None):
"""Closely emulate the interactive Python console.
The optional banner argument specify the banner to print
before the first interaction; by default it prints a banner
similar to the one printed by the real Python interpreter,
followed by the current class name in parentheses (so as not
to confuse this with the real interpreter -- since it's so
close!).
"""
try:
sys.ps1
except AttributeError:
sys.ps1 = ">>> "
try:
sys.ps2
except AttributeError:
sys.ps2 = "... "
cprt = 'Type "help", "copyright", "credits" or "license" for more information.'
if banner is None:
self.write("Python %s on %s\n%s\n(%s)\n" %
(sys.version, sys.platform, cprt,
self.__class__.__name__))
else:
self.write("%s\n" % str(banner))
more = 0
while 1:
try:
if more:
prompt = sys.ps2
else:
prompt = sys.ps1
try:
line = self.raw_input(prompt)
# Can be None if sys.stdin was redefined
encoding = getattr(sys.stdin, "encoding", None)
if encoding and not isinstance(line, unicode):
line = line.decode(encoding)
except EOFError:
self.write("\n")
break
else:
more = self.push(line)
except KeyboardInterrupt:
self.write("\nKeyboardInterrupt\n")
self.resetbuffer()
more = 0
def push(self, line):
"""Push a line to the interpreter.
The line should not have a trailing newline; it may have
internal newlines. The line is appended to a buffer and the
interpreter's runsource() method is called with the
concatenated contents of the buffer as source. If this
indicates that the command was executed or invalid, the buffer
is reset; otherwise, the command is incomplete, and the buffer
is left as it was after the line was appended. The return
value is 1 if more input is required, 0 if the line was dealt
with in some way (this is the same as runsource()).
"""
self.buffer.append(line)
source = "\n".join(self.buffer)
more = self.runsource(source, self.filename)
if not more:
self.resetbuffer()
return more
def raw_input(self, prompt=""):
"""Write a prompt and read a line.
The returned line does not include the trailing newline.
When the user enters the EOF key sequence, EOFError is raised.
The base implementation uses the built-in function
raw_input(); a subclass may replace this with a different
implementation.
"""
return raw_input(prompt)
def interact(banner=None, readfunc=None, local=None):
"""Closely emulate the interactive Python interpreter.
This is a backwards compatible interface to the InteractiveConsole
class. When readfunc is not specified, it attempts to import the
readline module to enable GNU readline if it is available.
Arguments (all optional, all default to None):
banner -- passed to InteractiveConsole.interact()
readfunc -- if not None, replaces InteractiveConsole.raw_input()
local -- passed to InteractiveInterpreter.__init__()
"""
console = InteractiveConsole(local)
if readfunc is not None:
console.raw_input = readfunc
else:
try:
import readline
except ImportError:
pass
console.interact(banner)
if __name__ == "__main__":
interact()

1098
src/python/stdlib/codecs.py Normal file

File diff suppressed because it is too large Load Diff

168
src/python/stdlib/codeop.py Normal file
View File

@@ -0,0 +1,168 @@
r"""Utilities to compile possibly incomplete Python source code.
This module provides two interfaces, broadly similar to the builtin
function compile(), which take program text, a filename and a 'mode'
and:
- Return code object if the command is complete and valid
- Return None if the command is incomplete
- Raise SyntaxError, ValueError or OverflowError if the command is a
syntax error (OverflowError and ValueError can be produced by
malformed literals).
Approach:
First, check if the source consists entirely of blank lines and
comments; if so, replace it with 'pass', because the built-in
parser doesn't always do the right thing for these.
Compile three times: as is, with \n, and with \n\n appended. If it
compiles as is, it's complete. If it compiles with one \n appended,
we expect more. If it doesn't compile either way, we compare the
error we get when compiling with \n or \n\n appended. If the errors
are the same, the code is broken. But if the errors are different, we
expect more. Not intuitive; not even guaranteed to hold in future
releases; but this matches the compiler's behavior from Python 1.4
through 2.2, at least.
Caveat:
It is possible (but not likely) that the parser stops parsing with a
successful outcome before reaching the end of the source; in this
case, trailing symbols may be ignored instead of causing an error.
For example, a backslash followed by two newlines may be followed by
arbitrary garbage. This will be fixed once the API for the parser is
better.
The two interfaces are:
compile_command(source, filename, symbol):
Compiles a single command in the manner described above.
CommandCompiler():
Instances of this class have __call__ methods identical in
signature to compile_command; the difference is that if the
instance compiles program text containing a __future__ statement,
the instance 'remembers' and compiles all subsequent program texts
with the statement in force.
The module also provides another class:
Compile():
Instances of this class act like the built-in function compile,
but with 'memory' in the sense described above.
"""
import __future__
_features = [getattr(__future__, fname)
for fname in __future__.all_feature_names]
__all__ = ["compile_command", "Compile", "CommandCompiler"]
PyCF_DONT_IMPLY_DEDENT = 0x200 # Matches pythonrun.h
def _maybe_compile(compiler, source, filename, symbol):
# Check for source consisting of only blank lines and comments
for line in source.split("\n"):
line = line.strip()
if line and line[0] != '#':
break # Leave it alone
else:
if symbol != "eval":
source = "pass" # Replace it with a 'pass' statement
err = err1 = err2 = None
code = code1 = code2 = None
try:
code = compiler(source, filename, symbol)
except SyntaxError, err:
pass
try:
code1 = compiler(source + "\n", filename, symbol)
except SyntaxError, err1:
pass
try:
code2 = compiler(source + "\n\n", filename, symbol)
except SyntaxError, err2:
pass
if code:
return code
if not code1 and repr(err1) == repr(err2):
raise SyntaxError, err1
def _compile(source, filename, symbol):
return compile(source, filename, symbol, PyCF_DONT_IMPLY_DEDENT)
def compile_command(source, filename="<input>", symbol="single"):
r"""Compile a command and determine whether it is incomplete.
Arguments:
source -- the source string; may contain \n characters
filename -- optional filename from which source was read; default
"<input>"
symbol -- optional grammar start symbol; "single" (default) or "eval"
Return value / exceptions raised:
- Return a code object if the command is complete and valid
- Return None if the command is incomplete
- Raise SyntaxError, ValueError or OverflowError if the command is a
syntax error (OverflowError and ValueError can be produced by
malformed literals).
"""
return _maybe_compile(_compile, source, filename, symbol)
class Compile:
"""Instances of this class behave much like the built-in compile
function, but if one is used to compile text containing a future
statement, it "remembers" and compiles all subsequent program texts
with the statement in force."""
def __init__(self):
self.flags = PyCF_DONT_IMPLY_DEDENT
def __call__(self, source, filename, symbol):
codeob = compile(source, filename, symbol, self.flags, 1)
for feature in _features:
if codeob.co_flags & feature.compiler_flag:
self.flags |= feature.compiler_flag
return codeob
class CommandCompiler:
"""Instances of this class have __call__ methods identical in
signature to compile_command; the difference is that if the
instance compiles program text containing a __future__ statement,
the instance 'remembers' and compiles all subsequent program texts
with the statement in force."""
def __init__(self,):
self.compiler = Compile()
def __call__(self, source, filename="<input>", symbol="single"):
r"""Compile a command and determine whether it is incomplete.
Arguments:
source -- the source string; may contain \n characters
filename -- optional filename from which source was read;
default "<input>"
symbol -- optional grammar start symbol; "single" (default) or
"eval"
Return value / exceptions raised:
- Return a code object if the command is complete and valid
- Return None if the command is incomplete
- Raise SyntaxError, ValueError or OverflowError if the command is a
syntax error (OverflowError and ValueError can be produced by
malformed literals).
"""
return _maybe_compile(self.compiler, source, filename, symbol)

View File

@@ -0,0 +1,627 @@
__all__ = ['Counter', 'deque', 'defaultdict', 'namedtuple', 'OrderedDict']
# For bootstrapping reasons, the collection ABCs are defined in _abcoll.py.
# They should however be considered an integral part of collections.py.
from _abcoll import *
import _abcoll
__all__ += _abcoll.__all__
from _collections import deque, defaultdict
from operator import itemgetter as _itemgetter, eq as _eq
from keyword import iskeyword as _iskeyword
import sys as _sys
import heapq as _heapq
from itertools import repeat as _repeat, chain as _chain, starmap as _starmap, \
ifilter as _ifilter, imap as _imap
try:
from thread import get_ident
except ImportError:
from dummy_thread import get_ident
def _recursive_repr(user_function):
'Decorator to make a repr function return "..." for a recursive call'
repr_running = set()
def wrapper(self):
key = id(self), get_ident()
if key in repr_running:
return '...'
repr_running.add(key)
try:
result = user_function(self)
finally:
repr_running.discard(key)
return result
# Can't use functools.wraps() here because of bootstrap issues
wrapper.__module__ = getattr(user_function, '__module__')
wrapper.__doc__ = getattr(user_function, '__doc__')
wrapper.__name__ = getattr(user_function, '__name__')
return wrapper
################################################################################
### OrderedDict
################################################################################
class OrderedDict(dict, MutableMapping):
'Dictionary that remembers insertion order'
# An inherited dict maps keys to values.
# The inherited dict provides __getitem__, __len__, __contains__, and get.
# The remaining methods are order-aware.
# Big-O running times for all methods are the same as for regular dictionaries.
# The internal self.__map dictionary maps keys to links in a doubly linked list.
# The circular doubly linked list starts and ends with a sentinel element.
# The sentinel element never gets deleted (this simplifies the algorithm).
# Each link is stored as a list of length three: [PREV, NEXT, KEY].
def __init__(self, *args, **kwds):
'''Initialize an ordered dictionary. Signature is the same as for
regular dictionaries, but keyword arguments are not recommended
because their insertion order is arbitrary.
'''
if len(args) > 1:
raise TypeError('expected at most 1 arguments, got %d' % len(args))
try:
self.__root
except AttributeError:
self.__root = root = [None, None, None] # sentinel node
PREV = 0
NEXT = 1
root[PREV] = root[NEXT] = root
self.__map = {}
self.update(*args, **kwds)
def __setitem__(self, key, value, PREV=0, NEXT=1, dict_setitem=dict.__setitem__):
'od.__setitem__(i, y) <==> od[i]=y'
# Setting a new item creates a new link which goes at the end of the linked
# list, and the inherited dictionary is updated with the new key/value pair.
if key not in self:
root = self.__root
last = root[PREV]
last[NEXT] = root[PREV] = self.__map[key] = [last, root, key]
dict_setitem(self, key, value)
def __delitem__(self, key, PREV=0, NEXT=1, dict_delitem=dict.__delitem__):
'od.__delitem__(y) <==> del od[y]'
# Deleting an existing item uses self.__map to find the link which is
# then removed by updating the links in the predecessor and successor nodes.
dict_delitem(self, key)
link = self.__map.pop(key)
link_prev = link[PREV]
link_next = link[NEXT]
link_prev[NEXT] = link_next
link_next[PREV] = link_prev
def __iter__(self, NEXT=1, KEY=2):
'od.__iter__() <==> iter(od)'
# Traverse the linked list in order.
root = self.__root
curr = root[NEXT]
while curr is not root:
yield curr[KEY]
curr = curr[NEXT]
def __reversed__(self, PREV=0, KEY=2):
'od.__reversed__() <==> reversed(od)'
# Traverse the linked list in reverse order.
root = self.__root
curr = root[PREV]
while curr is not root:
yield curr[KEY]
curr = curr[PREV]
def __reduce__(self):
'Return state information for pickling'
items = [[k, self[k]] for k in self]
tmp = self.__map, self.__root
del self.__map, self.__root
inst_dict = vars(self).copy()
self.__map, self.__root = tmp
if inst_dict:
return (self.__class__, (items,), inst_dict)
return self.__class__, (items,)
def clear(self):
'od.clear() -> None. Remove all items from od.'
try:
for node in self.__map.itervalues():
del node[:]
self.__root[:] = [self.__root, self.__root, None]
self.__map.clear()
except AttributeError:
pass
dict.clear(self)
setdefault = MutableMapping.setdefault
update = MutableMapping.update
pop = MutableMapping.pop
keys = MutableMapping.keys
values = MutableMapping.values
items = MutableMapping.items
iterkeys = MutableMapping.iterkeys
itervalues = MutableMapping.itervalues
iteritems = MutableMapping.iteritems
__ne__ = MutableMapping.__ne__
def viewkeys(self):
"od.viewkeys() -> a set-like object providing a view on od's keys"
return KeysView(self)
def viewvalues(self):
"od.viewvalues() -> an object providing a view on od's values"
return ValuesView(self)
def viewitems(self):
"od.viewitems() -> a set-like object providing a view on od's items"
return ItemsView(self)
def popitem(self, last=True):
'''od.popitem() -> (k, v), return and remove a (key, value) pair.
Pairs are returned in LIFO order if last is true or FIFO order if false.
'''
if not self:
raise KeyError('dictionary is empty')
key = next(reversed(self) if last else iter(self))
value = self.pop(key)
return key, value
@_recursive_repr
def __repr__(self):
'od.__repr__() <==> repr(od)'
if not self:
return '%s()' % (self.__class__.__name__,)
return '%s(%r)' % (self.__class__.__name__, self.items())
def copy(self):
'od.copy() -> a shallow copy of od'
return self.__class__(self)
@classmethod
def fromkeys(cls, iterable, value=None):
'''OD.fromkeys(S[, v]) -> New ordered dictionary with keys from S
and values equal to v (which defaults to None).
'''
d = cls()
for key in iterable:
d[key] = value
return d
def __eq__(self, other):
'''od.__eq__(y) <==> od==y. Comparison to another OD is order-sensitive
while comparison to a regular mapping is order-insensitive.
'''
if isinstance(other, OrderedDict):
return len(self)==len(other) and \
all(_imap(_eq, self.iteritems(), other.iteritems()))
return dict.__eq__(self, other)
################################################################################
### namedtuple
################################################################################
def namedtuple(typename, field_names, verbose=False, rename=False):
"""Returns a new subclass of tuple with named fields.
>>> Point = namedtuple('Point', 'x y')
>>> Point.__doc__ # docstring for the new class
'Point(x, y)'
>>> p = Point(11, y=22) # instantiate with positional args or keywords
>>> p[0] + p[1] # indexable like a plain tuple
33
>>> x, y = p # unpack like a regular tuple
>>> x, y
(11, 22)
>>> p.x + p.y # fields also accessable by name
33
>>> d = p._asdict() # convert to a dictionary
>>> d['x']
11
>>> Point(**d) # convert from a dictionary
Point(x=11, y=22)
>>> p._replace(x=100) # _replace() is like str.replace() but targets named fields
Point(x=100, y=22)
"""
# Parse and validate the field names. Validation serves two purposes,
# generating informative error messages and preventing template injection attacks.
if isinstance(field_names, basestring):
field_names = field_names.replace(',', ' ').split() # names separated by whitespace and/or commas
field_names = tuple(map(str, field_names))
if rename:
names = list(field_names)
seen = set()
for i, name in enumerate(names):
if (not all(c.isalnum() or c=='_' for c in name) or _iskeyword(name)
or not name or name[0].isdigit() or name.startswith('_')
or name in seen):
names[i] = '_%d' % i
seen.add(name)
field_names = tuple(names)
for name in (typename,) + field_names:
if not all(c.isalnum() or c=='_' for c in name):
raise ValueError('Type names and field names can only contain alphanumeric characters and underscores: %r' % name)
if _iskeyword(name):
raise ValueError('Type names and field names cannot be a keyword: %r' % name)
if name[0].isdigit():
raise ValueError('Type names and field names cannot start with a number: %r' % name)
seen_names = set()
for name in field_names:
if name.startswith('_') and not rename:
raise ValueError('Field names cannot start with an underscore: %r' % name)
if name in seen_names:
raise ValueError('Encountered duplicate field name: %r' % name)
seen_names.add(name)
# Create and fill-in the class template
numfields = len(field_names)
argtxt = repr(field_names).replace("'", "")[1:-1] # tuple repr without parens or quotes
reprtxt = ', '.join('%s=%%r' % name for name in field_names)
template = '''class %(typename)s(tuple):
'%(typename)s(%(argtxt)s)' \n
__slots__ = () \n
_fields = %(field_names)r \n
def __new__(_cls, %(argtxt)s):
'Create new instance of %(typename)s(%(argtxt)s)'
return _tuple.__new__(_cls, (%(argtxt)s)) \n
@classmethod
def _make(cls, iterable, new=tuple.__new__, len=len):
'Make a new %(typename)s object from a sequence or iterable'
result = new(cls, iterable)
if len(result) != %(numfields)d:
raise TypeError('Expected %(numfields)d arguments, got %%d' %% len(result))
return result \n
def __repr__(self):
'Return a nicely formatted representation string'
return '%(typename)s(%(reprtxt)s)' %% self \n
def _asdict(self):
'Return a new OrderedDict which maps field names to their values'
return OrderedDict(zip(self._fields, self)) \n
def _replace(_self, **kwds):
'Return a new %(typename)s object replacing specified fields with new values'
result = _self._make(map(kwds.pop, %(field_names)r, _self))
if kwds:
raise ValueError('Got unexpected field names: %%r' %% kwds.keys())
return result \n
def __getnewargs__(self):
'Return self as a plain tuple. Used by copy and pickle.'
return tuple(self) \n\n''' % locals()
for i, name in enumerate(field_names):
template += " %s = _property(_itemgetter(%d), doc='Alias for field number %d')\n" % (name, i, i)
if verbose:
print template
# Execute the template string in a temporary namespace and
# support tracing utilities by setting a value for frame.f_globals['__name__']
namespace = dict(_itemgetter=_itemgetter, __name__='namedtuple_%s' % typename,
OrderedDict=OrderedDict, _property=property, _tuple=tuple)
try:
exec template in namespace
except SyntaxError, e:
raise SyntaxError(e.message + ':\n' + template)
result = namespace[typename]
# For pickling to work, the __module__ variable needs to be set to the frame
# where the named tuple is created. Bypass this step in enviroments where
# sys._getframe is not defined (Jython for example) or sys._getframe is not
# defined for arguments greater than 0 (IronPython).
try:
result.__module__ = _sys._getframe(1).f_globals.get('__name__', '__main__')
except (AttributeError, ValueError):
pass
return result
########################################################################
### Counter
########################################################################
class Counter(dict):
'''Dict subclass for counting hashable items. Sometimes called a bag
or multiset. Elements are stored as dictionary keys and their counts
are stored as dictionary values.
>>> c = Counter('abracadabra') # count elements from a string
>>> c.most_common(3) # three most common elements
[('a', 5), ('r', 2), ('b', 2)]
>>> sorted(c) # list all unique elements
['a', 'b', 'c', 'd', 'r']
>>> ''.join(sorted(c.elements())) # list elements with repetitions
'aaaaabbcdrr'
>>> sum(c.values()) # total of all counts
11
>>> c['a'] # count of letter 'a'
5
>>> for elem in 'shazam': # update counts from an iterable
... c[elem] += 1 # by adding 1 to each element's count
>>> c['a'] # now there are seven 'a'
7
>>> del c['r'] # remove all 'r'
>>> c['r'] # now there are zero 'r'
0
>>> d = Counter('simsalabim') # make another counter
>>> c.update(d) # add in the second counter
>>> c['a'] # now there are nine 'a'
9
>>> c.clear() # empty the counter
>>> c
Counter()
Note: If a count is set to zero or reduced to zero, it will remain
in the counter until the entry is deleted or the counter is cleared:
>>> c = Counter('aaabbc')
>>> c['b'] -= 2 # reduce the count of 'b' by two
>>> c.most_common() # 'b' is still in, but its count is zero
[('a', 3), ('c', 1), ('b', 0)]
'''
# References:
# http://en.wikipedia.org/wiki/Multiset
# http://www.gnu.org/software/smalltalk/manual-base/html_node/Bag.html
# http://www.demo2s.com/Tutorial/Cpp/0380__set-multiset/Catalog0380__set-multiset.htm
# http://code.activestate.com/recipes/259174/
# Knuth, TAOCP Vol. II section 4.6.3
def __init__(self, iterable=None, **kwds):
'''Create a new, empty Counter object. And if given, count elements
from an input iterable. Or, initialize the count from another mapping
of elements to their counts.
>>> c = Counter() # a new, empty counter
>>> c = Counter('gallahad') # a new counter from an iterable
>>> c = Counter({'a': 4, 'b': 2}) # a new counter from a mapping
>>> c = Counter(a=4, b=2) # a new counter from keyword args
'''
self.update(iterable, **kwds)
def __missing__(self, key):
'The count of elements not in the Counter is zero.'
# Needed so that self[missing_item] does not raise KeyError
return 0
def most_common(self, n=None):
'''List the n most common elements and their counts from the most
common to the least. If n is None, then list all element counts.
>>> Counter('abracadabra').most_common(3)
[('a', 5), ('r', 2), ('b', 2)]
'''
# Emulate Bag.sortedByCount from Smalltalk
if n is None:
return sorted(self.iteritems(), key=_itemgetter(1), reverse=True)
return _heapq.nlargest(n, self.iteritems(), key=_itemgetter(1))
def elements(self):
'''Iterator over elements repeating each as many times as its count.
>>> c = Counter('ABCABC')
>>> sorted(c.elements())
['A', 'A', 'B', 'B', 'C', 'C']
# Knuth's example for prime factors of 1836: 2**2 * 3**3 * 17**1
>>> prime_factors = Counter({2: 2, 3: 3, 17: 1})
>>> product = 1
>>> for factor in prime_factors.elements(): # loop over factors
... product *= factor # and multiply them
>>> product
1836
Note, if an element's count has been set to zero or is a negative
number, elements() will ignore it.
'''
# Emulate Bag.do from Smalltalk and Multiset.begin from C++.
return _chain.from_iterable(_starmap(_repeat, self.iteritems()))
# Override dict methods where necessary
@classmethod
def fromkeys(cls, iterable, v=None):
# There is no equivalent method for counters because setting v=1
# means that no element can have a count greater than one.
raise NotImplementedError(
'Counter.fromkeys() is undefined. Use Counter(iterable) instead.')
def update(self, iterable=None, **kwds):
'''Like dict.update() but add counts instead of replacing them.
Source can be an iterable, a dictionary, or another Counter instance.
>>> c = Counter('which')
>>> c.update('witch') # add elements from another iterable
>>> d = Counter('watch')
>>> c.update(d) # add elements from another counter
>>> c['h'] # four 'h' in which, witch, and watch
4
'''
# The regular dict.update() operation makes no sense here because the
# replace behavior results in the some of original untouched counts
# being mixed-in with all of the other counts for a mismash that
# doesn't have a straight-forward interpretation in most counting
# contexts. Instead, we implement straight-addition. Both the inputs
# and outputs are allowed to contain zero and negative counts.
if iterable is not None:
if isinstance(iterable, Mapping):
if self:
self_get = self.get
for elem, count in iterable.iteritems():
self[elem] = self_get(elem, 0) + count
else:
dict.update(self, iterable) # fast path when counter is empty
else:
self_get = self.get
for elem in iterable:
self[elem] = self_get(elem, 0) + 1
if kwds:
self.update(kwds)
def subtract(self, iterable=None, **kwds):
'''Like dict.update() but subtracts counts instead of replacing them.
Counts can be reduced below zero. Both the inputs and outputs are
allowed to contain zero and negative counts.
Source can be an iterable, a dictionary, or another Counter instance.
>>> c = Counter('which')
>>> c.subtract('witch') # subtract elements from another iterable
>>> c.subtract(Counter('watch')) # subtract elements from another counter
>>> c['h'] # 2 in which, minus 1 in witch, minus 1 in watch
0
>>> c['w'] # 1 in which, minus 1 in witch, minus 1 in watch
-1
'''
if iterable is not None:
self_get = self.get
if isinstance(iterable, Mapping):
for elem, count in iterable.items():
self[elem] = self_get(elem, 0) - count
else:
for elem in iterable:
self[elem] = self_get(elem, 0) - 1
if kwds:
self.subtract(kwds)
def copy(self):
'Like dict.copy() but returns a Counter instance instead of a dict.'
return Counter(self)
def __delitem__(self, elem):
'Like dict.__delitem__() but does not raise KeyError for missing values.'
if elem in self:
dict.__delitem__(self, elem)
def __repr__(self):
if not self:
return '%s()' % self.__class__.__name__
items = ', '.join(map('%r: %r'.__mod__, self.most_common()))
return '%s({%s})' % (self.__class__.__name__, items)
# Multiset-style mathematical operations discussed in:
# Knuth TAOCP Volume II section 4.6.3 exercise 19
# and at http://en.wikipedia.org/wiki/Multiset
#
# Outputs guaranteed to only include positive counts.
#
# To strip negative and zero counts, add-in an empty counter:
# c += Counter()
def __add__(self, other):
'''Add counts from two counters.
>>> Counter('abbb') + Counter('bcc')
Counter({'b': 4, 'c': 2, 'a': 1})
'''
if not isinstance(other, Counter):
return NotImplemented
result = Counter()
for elem in set(self) | set(other):
newcount = self[elem] + other[elem]
if newcount > 0:
result[elem] = newcount
return result
def __sub__(self, other):
''' Subtract count, but keep only results with positive counts.
>>> Counter('abbbc') - Counter('bccd')
Counter({'b': 2, 'a': 1})
'''
if not isinstance(other, Counter):
return NotImplemented
result = Counter()
for elem in set(self) | set(other):
newcount = self[elem] - other[elem]
if newcount > 0:
result[elem] = newcount
return result
def __or__(self, other):
'''Union is the maximum of value in either of the input counters.
>>> Counter('abbb') | Counter('bcc')
Counter({'b': 3, 'c': 2, 'a': 1})
'''
if not isinstance(other, Counter):
return NotImplemented
result = Counter()
for elem in set(self) | set(other):
p, q = self[elem], other[elem]
newcount = q if p < q else p
if newcount > 0:
result[elem] = newcount
return result
def __and__(self, other):
''' Intersection is the minimum of corresponding counts.
>>> Counter('abbb') & Counter('bcc')
Counter({'b': 1})
'''
if not isinstance(other, Counter):
return NotImplemented
result = Counter()
if len(self) < len(other):
self, other = other, self
for elem in _ifilter(self.__contains__, other):
p, q = self[elem], other[elem]
newcount = p if p < q else q
if newcount > 0:
result[elem] = newcount
return result
if __name__ == '__main__':
# verify that instances can be pickled
from cPickle import loads, dumps
Point = namedtuple('Point', 'x, y', True)
p = Point(x=10, y=20)
assert p == loads(dumps(p))
# test and demonstrate ability to override methods
class Point(namedtuple('Point', 'x y')):
__slots__ = ()
@property
def hypot(self):
return (self.x ** 2 + self.y ** 2) ** 0.5
def __str__(self):
return 'Point: x=%6.3f y=%6.3f hypot=%6.3f' % (self.x, self.y, self.hypot)
for p in Point(3, 4), Point(14, 5/7.):
print p
class Point(namedtuple('Point', 'x y')):
'Point class with optimized _make() and _replace() without error-checking'
__slots__ = ()
_make = classmethod(tuple.__new__)
def _replace(self, _map=map, **kwds):
return self._make(_map(kwds.get, ('x', 'y'), self))
print Point(11, 22)._replace(x=100)
Point3D = namedtuple('Point3D', Point._fields + ('z',))
print Point3D.__doc__
import doctest
TestResults = namedtuple('TestResults', 'failed attempted')
print TestResults(*doctest.testmod())

View File

@@ -0,0 +1,156 @@
"""Conversion functions between RGB and other color systems.
This modules provides two functions for each color system ABC:
rgb_to_abc(r, g, b) --> a, b, c
abc_to_rgb(a, b, c) --> r, g, b
All inputs and outputs are triples of floats in the range [0.0...1.0]
(with the exception of I and Q, which covers a slightly larger range).
Inputs outside the valid range may cause exceptions or invalid outputs.
Supported color systems:
RGB: Red, Green, Blue components
YIQ: Luminance, Chrominance (used by composite video signals)
HLS: Hue, Luminance, Saturation
HSV: Hue, Saturation, Value
"""
# References:
# http://en.wikipedia.org/wiki/YIQ
# http://en.wikipedia.org/wiki/HLS_color_space
# http://en.wikipedia.org/wiki/HSV_color_space
__all__ = ["rgb_to_yiq","yiq_to_rgb","rgb_to_hls","hls_to_rgb",
"rgb_to_hsv","hsv_to_rgb"]
# Some floating point constants
ONE_THIRD = 1.0/3.0
ONE_SIXTH = 1.0/6.0
TWO_THIRD = 2.0/3.0
# YIQ: used by composite video signals (linear combinations of RGB)
# Y: perceived grey level (0.0 == black, 1.0 == white)
# I, Q: color components
def rgb_to_yiq(r, g, b):
y = 0.30*r + 0.59*g + 0.11*b
i = 0.60*r - 0.28*g - 0.32*b
q = 0.21*r - 0.52*g + 0.31*b
return (y, i, q)
def yiq_to_rgb(y, i, q):
r = y + 0.948262*i + 0.624013*q
g = y - 0.276066*i - 0.639810*q
b = y - 1.105450*i + 1.729860*q
if r < 0.0:
r = 0.0
if g < 0.0:
g = 0.0
if b < 0.0:
b = 0.0
if r > 1.0:
r = 1.0
if g > 1.0:
g = 1.0
if b > 1.0:
b = 1.0
return (r, g, b)
# HLS: Hue, Luminance, Saturation
# H: position in the spectrum
# L: color lightness
# S: color saturation
def rgb_to_hls(r, g, b):
maxc = max(r, g, b)
minc = min(r, g, b)
# XXX Can optimize (maxc+minc) and (maxc-minc)
l = (minc+maxc)/2.0
if minc == maxc:
return 0.0, l, 0.0
if l <= 0.5:
s = (maxc-minc) / (maxc+minc)
else:
s = (maxc-minc) / (2.0-maxc-minc)
rc = (maxc-r) / (maxc-minc)
gc = (maxc-g) / (maxc-minc)
bc = (maxc-b) / (maxc-minc)
if r == maxc:
h = bc-gc
elif g == maxc:
h = 2.0+rc-bc
else:
h = 4.0+gc-rc
h = (h/6.0) % 1.0
return h, l, s
def hls_to_rgb(h, l, s):
if s == 0.0:
return l, l, l
if l <= 0.5:
m2 = l * (1.0+s)
else:
m2 = l+s-(l*s)
m1 = 2.0*l - m2
return (_v(m1, m2, h+ONE_THIRD), _v(m1, m2, h), _v(m1, m2, h-ONE_THIRD))
def _v(m1, m2, hue):
hue = hue % 1.0
if hue < ONE_SIXTH:
return m1 + (m2-m1)*hue*6.0
if hue < 0.5:
return m2
if hue < TWO_THIRD:
return m1 + (m2-m1)*(TWO_THIRD-hue)*6.0
return m1
# HSV: Hue, Saturation, Value
# H: position in the spectrum
# S: color saturation ("purity")
# V: color brightness
def rgb_to_hsv(r, g, b):
maxc = max(r, g, b)
minc = min(r, g, b)
v = maxc
if minc == maxc:
return 0.0, 0.0, v
s = (maxc-minc) / maxc
rc = (maxc-r) / (maxc-minc)
gc = (maxc-g) / (maxc-minc)
bc = (maxc-b) / (maxc-minc)
if r == maxc:
h = bc-gc
elif g == maxc:
h = 2.0+rc-bc
else:
h = 4.0+gc-rc
h = (h/6.0) % 1.0
return h, s, v
def hsv_to_rgb(h, s, v):
if s == 0.0:
return v, v, v
i = int(h*6.0) # XXX assume int() truncates!
f = (h*6.0) - i
p = v*(1.0 - s)
q = v*(1.0 - s*f)
t = v*(1.0 - s*(1.0-f))
i = i%6
if i == 0:
return v, t, p
if i == 1:
return q, v, p
if i == 2:
return p, v, t
if i == 3:
return p, q, v
if i == 4:
return t, p, v
if i == 5:
return v, p, q
# Cannot get here

View File

@@ -0,0 +1,90 @@
"""Execute shell commands via os.popen() and return status, output.
Interface summary:
import commands
outtext = commands.getoutput(cmd)
(exitstatus, outtext) = commands.getstatusoutput(cmd)
outtext = commands.getstatus(file) # returns output of "ls -ld file"
A trailing newline is removed from the output string.
Encapsulates the basic operation:
pipe = os.popen('{ ' + cmd + '; } 2>&1', 'r')
text = pipe.read()
sts = pipe.close()
[Note: it would be nice to add functions to interpret the exit status.]
"""
from warnings import warnpy3k
warnpy3k("the commands module has been removed in Python 3.0; "
"use the subprocess module instead", stacklevel=2)
del warnpy3k
__all__ = ["getstatusoutput","getoutput","getstatus"]
# Module 'commands'
#
# Various tools for executing commands and looking at their output and status.
#
# NB This only works (and is only relevant) for UNIX.
# Get 'ls -l' status for an object into a string
#
def getstatus(file):
"""Return output of "ls -ld <file>" in a string."""
import warnings
warnings.warn("commands.getstatus() is deprecated", DeprecationWarning, 2)
return getoutput('ls -ld' + mkarg(file))
# Get the output from a shell command into a string.
# The exit status is ignored; a trailing newline is stripped.
# Assume the command will work with '{ ... ; } 2>&1' around it..
#
def getoutput(cmd):
"""Return output (stdout or stderr) of executing cmd in a shell."""
return getstatusoutput(cmd)[1]
# Ditto but preserving the exit status.
# Returns a pair (sts, output)
#
def getstatusoutput(cmd):
"""Return (status, output) of executing cmd in a shell."""
import os
pipe = os.popen('{ ' + cmd + '; } 2>&1', 'r')
text = pipe.read()
sts = pipe.close()
if sts is None: sts = 0
if text[-1:] == '\n': text = text[:-1]
return sts, text
# Make command argument from directory and pathname (prefix space, add quotes).
#
def mk2arg(head, x):
import os
return mkarg(os.path.join(head, x))
# Make a shell command argument from a string.
# Return a string beginning with a space followed by a shell-quoted
# version of the argument.
# Two strategies: enclose in single quotes if it contains none;
# otherwise, enclose in double quotes and prefix quotable characters
# with backslash.
#
def mkarg(x):
if '\'' not in x:
return ' \'' + x + '\''
s = ' "'
for c in x:
if c in '\\$"`':
s = s + '\\'
s = s + c
s = s + '"'
return s

View File

@@ -0,0 +1,216 @@
"""Module/script to "compile" all .py files to .pyc (or .pyo) file.
When called as a script with arguments, this compiles the directories
given as arguments recursively; the -l option prevents it from
recursing into directories.
Without arguments, if compiles all modules on sys.path, without
recursing into subdirectories. (Even though it should do so for
packages -- for now, you'll have to deal with packages separately.)
See module py_compile for details of the actual byte-compilation.
"""
import os
import sys
import py_compile
import struct
import imp
__all__ = ["compile_dir","compile_file","compile_path"]
def compile_dir(dir, maxlevels=10, ddir=None,
force=0, rx=None, quiet=0):
"""Byte-compile all modules in the given directory tree.
Arguments (only dir is required):
dir: the directory to byte-compile
maxlevels: maximum recursion level (default 10)
ddir: if given, purported directory name (this is the
directory name that will show up in error messages)
force: if 1, force compilation, even if timestamps are up-to-date
quiet: if 1, be quiet during compilation
"""
if not quiet:
print 'Listing', dir, '...'
try:
names = os.listdir(dir)
except os.error:
print "Can't list", dir
names = []
names.sort()
success = 1
for name in names:
fullname = os.path.join(dir, name)
if ddir is not None:
dfile = os.path.join(ddir, name)
else:
dfile = None
if not os.path.isdir(fullname):
if not compile_file(fullname, ddir, force, rx, quiet):
success = 0
elif maxlevels > 0 and \
name != os.curdir and name != os.pardir and \
os.path.isdir(fullname) and \
not os.path.islink(fullname):
if not compile_dir(fullname, maxlevels - 1, dfile, force, rx,
quiet):
success = 0
return success
def compile_file(fullname, ddir=None, force=0, rx=None, quiet=0):
"""Byte-compile file.
file: the file to byte-compile
ddir: if given, purported directory name (this is the
directory name that will show up in error messages)
force: if 1, force compilation, even if timestamps are up-to-date
quiet: if 1, be quiet during compilation
"""
success = 1
name = os.path.basename(fullname)
if ddir is not None:
dfile = os.path.join(ddir, name)
else:
dfile = None
if rx is not None:
mo = rx.search(fullname)
if mo:
return success
if os.path.isfile(fullname):
head, tail = name[:-3], name[-3:]
if tail == '.py':
if not force:
try:
mtime = int(os.stat(fullname).st_mtime)
expect = struct.pack('<4sl', imp.get_magic(), mtime)
cfile = fullname + (__debug__ and 'c' or 'o')
with open(cfile, 'rb') as chandle:
actual = chandle.read(8)
if expect == actual:
return success
except IOError:
pass
if not quiet:
print 'Compiling', fullname, '...'
try:
ok = py_compile.compile(fullname, None, dfile, True)
except py_compile.PyCompileError,err:
if quiet:
print 'Compiling', fullname, '...'
print err.msg
success = 0
except IOError, e:
print "Sorry", e
success = 0
else:
if ok == 0:
success = 0
return success
def compile_path(skip_curdir=1, maxlevels=0, force=0, quiet=0):
"""Byte-compile all module on sys.path.
Arguments (all optional):
skip_curdir: if true, skip current directory (default true)
maxlevels: max recursion level (default 0)
force: as for compile_dir() (default 0)
quiet: as for compile_dir() (default 0)
"""
success = 1
for dir in sys.path:
if (not dir or dir == os.curdir) and skip_curdir:
print 'Skipping current directory'
else:
success = success and compile_dir(dir, maxlevels, None,
force, quiet=quiet)
return success
def expand_args(args, flist):
"""read names in flist and append to args"""
expanded = args[:]
if flist:
try:
if flist == '-':
fd = sys.stdin
else:
fd = open(flist)
while 1:
line = fd.readline()
if not line:
break
expanded.append(line[:-1])
except IOError:
print "Error reading file list %s" % flist
raise
return expanded
def main():
"""Script main program."""
import getopt
try:
opts, args = getopt.getopt(sys.argv[1:], 'lfqd:x:i:')
except getopt.error, msg:
print msg
print "usage: python compileall.py [-l] [-f] [-q] [-d destdir] " \
"[-x regexp] [-i list] [directory|file ...]"
print "-l: don't recurse down"
print "-f: force rebuild even if timestamps are up-to-date"
print "-q: quiet operation"
print "-d destdir: purported directory name for error messages"
print " if no directory arguments, -l sys.path is assumed"
print "-x regexp: skip files matching the regular expression regexp"
print " the regexp is searched for in the full path of the file"
print "-i list: expand list with its content (file and directory names)"
sys.exit(2)
maxlevels = 10
ddir = None
force = 0
quiet = 0
rx = None
flist = None
for o, a in opts:
if o == '-l': maxlevels = 0
if o == '-d': ddir = a
if o == '-f': force = 1
if o == '-q': quiet = 1
if o == '-x':
import re
rx = re.compile(a)
if o == '-i': flist = a
if ddir:
if len(args) != 1 and not os.path.isdir(args[0]):
print "-d destdir require exactly one directory argument"
sys.exit(2)
success = 1
try:
if args or flist:
try:
if flist:
args = expand_args(args, flist)
except IOError:
success = 0
if success:
for arg in args:
if os.path.isdir(arg):
if not compile_dir(arg, maxlevels, ddir,
force, rx, quiet):
success = 0
else:
if not compile_file(arg, ddir, force, rx, quiet):
success = 0
else:
success = compile_path()
except KeyboardInterrupt:
print "\n[interrupt]"
success = 0
return success
if __name__ == '__main__':
exit_status = int(not main())
sys.exit(exit_status)

View File

@@ -0,0 +1,31 @@
"""Package for parsing and compiling Python source code
There are several functions defined at the top level that are imported
from modules contained in the package.
parse(buf, mode="exec") -> AST
Converts a string containing Python source code to an abstract
syntax tree (AST). The AST is defined in compiler.ast.
parseFile(path) -> AST
The same as parse(open(path))
walk(ast, visitor, verbose=None)
Does a pre-order walk over the ast using the visitor instance.
See compiler.visitor for details.
compile(source, filename, mode, flags=None, dont_inherit=None)
Returns a code object. A replacement for the builtin compile() function.
compileFile(filename)
Generates a .pyc file by compiling filename.
"""
import warnings
warnings.warn("The compiler package is deprecated and removed in Python 3.x.",
DeprecationWarning, stacklevel=2)
from compiler.transformer import parse, parseFile
from compiler.visitor import walk
from compiler.pycodegen import compile, compileFile

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,23 @@
# operation flags
OP_ASSIGN = 'OP_ASSIGN'
OP_DELETE = 'OP_DELETE'
OP_APPLY = 'OP_APPLY'
SC_LOCAL = 1
SC_GLOBAL_IMPLICIT = 2
SC_GLOBAL_EXPLICT = 3
SC_FREE = 4
SC_CELL = 5
SC_UNKNOWN = 6
CO_OPTIMIZED = 0x0001
CO_NEWLOCALS = 0x0002
CO_VARARGS = 0x0004
CO_VARKEYWORDS = 0x0008
CO_NESTED = 0x0010
CO_GENERATOR = 0x0020
CO_GENERATOR_ALLOWED = 0
CO_FUTURE_DIVISION = 0x2000
CO_FUTURE_ABSIMPORT = 0x4000
CO_FUTURE_WITH_STATEMENT = 0x8000
CO_FUTURE_PRINT_FUNCTION = 0x10000

View File

@@ -0,0 +1,74 @@
"""Parser for future statements
"""
from compiler import ast, walk
def is_future(stmt):
"""Return true if statement is a well-formed future statement"""
if not isinstance(stmt, ast.From):
return 0
if stmt.modname == "__future__":
return 1
else:
return 0
class FutureParser:
features = ("nested_scopes", "generators", "division",
"absolute_import", "with_statement", "print_function",
"unicode_literals")
def __init__(self):
self.found = {} # set
def visitModule(self, node):
stmt = node.node
for s in stmt.nodes:
if not self.check_stmt(s):
break
def check_stmt(self, stmt):
if is_future(stmt):
for name, asname in stmt.names:
if name in self.features:
self.found[name] = 1
else:
raise SyntaxError, \
"future feature %s is not defined" % name
stmt.valid_future = 1
return 1
return 0
def get_features(self):
"""Return list of features enabled by future statements"""
return self.found.keys()
class BadFutureParser:
"""Check for invalid future statements"""
def visitFrom(self, node):
if hasattr(node, 'valid_future'):
return
if node.modname != "__future__":
return
raise SyntaxError, "invalid future statement " + repr(node)
def find_futures(node):
p1 = FutureParser()
p2 = BadFutureParser()
walk(node, p1)
walk(node, p2)
return p1.get_features()
if __name__ == "__main__":
import sys
from compiler import parseFile, walk
for file in sys.argv[1:]:
print file
tree = parseFile(file)
v = FutureParser()
walk(tree, v)
print v.found
print

View File

@@ -0,0 +1,73 @@
def flatten(tup):
elts = []
for elt in tup:
if isinstance(elt, tuple):
elts = elts + flatten(elt)
else:
elts.append(elt)
return elts
class Set:
def __init__(self):
self.elts = {}
def __len__(self):
return len(self.elts)
def __contains__(self, elt):
return elt in self.elts
def add(self, elt):
self.elts[elt] = elt
def elements(self):
return self.elts.keys()
def has_elt(self, elt):
return elt in self.elts
def remove(self, elt):
del self.elts[elt]
def copy(self):
c = Set()
c.elts.update(self.elts)
return c
class Stack:
def __init__(self):
self.stack = []
self.pop = self.stack.pop
def __len__(self):
return len(self.stack)
def push(self, elt):
self.stack.append(elt)
def top(self):
return self.stack[-1]
def __getitem__(self, index): # needed by visitContinue()
return self.stack[index]
MANGLE_LEN = 256 # magic constant from compile.c
def mangle(name, klass):
if not name.startswith('__'):
return name
if len(name) + 2 >= MANGLE_LEN:
return name
if name.endswith('__'):
return name
try:
i = 0
while klass[i] == '_':
i = i + 1
except IndexError:
return name
klass = klass[i:]
tlen = len(klass) + len(name)
if tlen > MANGLE_LEN:
klass = klass[:MANGLE_LEN-tlen]
return "_%s%s" % (klass, name)
def set_filename(filename, tree):
"""Set the filename attribute to filename on every node in tree"""
worklist = [tree]
while worklist:
node = worklist.pop(0)
node.filename = filename
worklist.extend(node.getChildNodes())

View File

@@ -0,0 +1,763 @@
"""A flow graph representation for Python bytecode"""
import dis
import types
import sys
from compiler import misc
from compiler.consts \
import CO_OPTIMIZED, CO_NEWLOCALS, CO_VARARGS, CO_VARKEYWORDS
class FlowGraph:
def __init__(self):
self.current = self.entry = Block()
self.exit = Block("exit")
self.blocks = misc.Set()
self.blocks.add(self.entry)
self.blocks.add(self.exit)
def startBlock(self, block):
if self._debug:
if self.current:
print "end", repr(self.current)
print " next", self.current.next
print " prev", self.current.prev
print " ", self.current.get_children()
print repr(block)
self.current = block
def nextBlock(self, block=None):
# XXX think we need to specify when there is implicit transfer
# from one block to the next. might be better to represent this
# with explicit JUMP_ABSOLUTE instructions that are optimized
# out when they are unnecessary.
#
# I think this strategy works: each block has a child
# designated as "next" which is returned as the last of the
# children. because the nodes in a graph are emitted in
# reverse post order, the "next" block will always be emitted
# immediately after its parent.
# Worry: maintaining this invariant could be tricky
if block is None:
block = self.newBlock()
# Note: If the current block ends with an unconditional control
# transfer, then it is techically incorrect to add an implicit
# transfer to the block graph. Doing so results in code generation
# for unreachable blocks. That doesn't appear to be very common
# with Python code and since the built-in compiler doesn't optimize
# it out we don't either.
self.current.addNext(block)
self.startBlock(block)
def newBlock(self):
b = Block()
self.blocks.add(b)
return b
def startExitBlock(self):
self.startBlock(self.exit)
_debug = 0
def _enable_debug(self):
self._debug = 1
def _disable_debug(self):
self._debug = 0
def emit(self, *inst):
if self._debug:
print "\t", inst
if len(inst) == 2 and isinstance(inst[1], Block):
self.current.addOutEdge(inst[1])
self.current.emit(inst)
def getBlocksInOrder(self):
"""Return the blocks in reverse postorder
i.e. each node appears before all of its successors
"""
order = order_blocks(self.entry, self.exit)
return order
def getBlocks(self):
return self.blocks.elements()
def getRoot(self):
"""Return nodes appropriate for use with dominator"""
return self.entry
def getContainedGraphs(self):
l = []
for b in self.getBlocks():
l.extend(b.getContainedGraphs())
return l
def order_blocks(start_block, exit_block):
"""Order blocks so that they are emitted in the right order"""
# Rules:
# - when a block has a next block, the next block must be emitted just after
# - when a block has followers (relative jumps), it must be emitted before
# them
# - all reachable blocks must be emitted
order = []
# Find all the blocks to be emitted.
remaining = set()
todo = [start_block]
while todo:
b = todo.pop()
if b in remaining:
continue
remaining.add(b)
for c in b.get_children():
if c not in remaining:
todo.append(c)
# A block is dominated by another block if that block must be emitted
# before it.
dominators = {}
for b in remaining:
if __debug__ and b.next:
assert b is b.next[0].prev[0], (b, b.next)
# Make sure every block appears in dominators, even if no
# other block must precede it.
dominators.setdefault(b, set())
# preceeding blocks dominate following blocks
for c in b.get_followers():
while 1:
dominators.setdefault(c, set()).add(b)
# Any block that has a next pointer leading to c is also
# dominated because the whole chain will be emitted at once.
# Walk backwards and add them all.
if c.prev and c.prev[0] is not b:
c = c.prev[0]
else:
break
def find_next():
# Find a block that can be emitted next.
for b in remaining:
for c in dominators[b]:
if c in remaining:
break # can't emit yet, dominated by a remaining block
else:
return b
assert 0, 'circular dependency, cannot find next block'
b = start_block
while 1:
order.append(b)
remaining.discard(b)
if b.next:
b = b.next[0]
continue
elif b is not exit_block and not b.has_unconditional_transfer():
order.append(exit_block)
if not remaining:
break
b = find_next()
return order
class Block:
_count = 0
def __init__(self, label=''):
self.insts = []
self.outEdges = set()
self.label = label
self.bid = Block._count
self.next = []
self.prev = []
Block._count = Block._count + 1
def __repr__(self):
if self.label:
return "<block %s id=%d>" % (self.label, self.bid)
else:
return "<block id=%d>" % (self.bid)
def __str__(self):
insts = map(str, self.insts)
return "<block %s %d:\n%s>" % (self.label, self.bid,
'\n'.join(insts))
def emit(self, inst):
op = inst[0]
self.insts.append(inst)
def getInstructions(self):
return self.insts
def addOutEdge(self, block):
self.outEdges.add(block)
def addNext(self, block):
self.next.append(block)
assert len(self.next) == 1, map(str, self.next)
block.prev.append(self)
assert len(block.prev) == 1, map(str, block.prev)
_uncond_transfer = ('RETURN_VALUE', 'RAISE_VARARGS',
'JUMP_ABSOLUTE', 'JUMP_FORWARD', 'CONTINUE_LOOP',
)
def has_unconditional_transfer(self):
"""Returns True if there is an unconditional transfer to an other block
at the end of this block. This means there is no risk for the bytecode
executer to go past this block's bytecode."""
try:
op, arg = self.insts[-1]
except (IndexError, ValueError):
return
return op in self._uncond_transfer
def get_children(self):
return list(self.outEdges) + self.next
def get_followers(self):
"""Get the whole list of followers, including the next block."""
followers = set(self.next)
# Blocks that must be emitted *after* this one, because of
# bytecode offsets (e.g. relative jumps) pointing to them.
for inst in self.insts:
if inst[0] in PyFlowGraph.hasjrel:
followers.add(inst[1])
return followers
def getContainedGraphs(self):
"""Return all graphs contained within this block.
For example, a MAKE_FUNCTION block will contain a reference to
the graph for the function body.
"""
contained = []
for inst in self.insts:
if len(inst) == 1:
continue
op = inst[1]
if hasattr(op, 'graph'):
contained.append(op.graph)
return contained
# flags for code objects
# the FlowGraph is transformed in place; it exists in one of these states
RAW = "RAW"
FLAT = "FLAT"
CONV = "CONV"
DONE = "DONE"
class PyFlowGraph(FlowGraph):
super_init = FlowGraph.__init__
def __init__(self, name, filename, args=(), optimized=0, klass=None):
self.super_init()
self.name = name
self.filename = filename
self.docstring = None
self.args = args # XXX
self.argcount = getArgCount(args)
self.klass = klass
if optimized:
self.flags = CO_OPTIMIZED | CO_NEWLOCALS
else:
self.flags = 0
self.consts = []
self.names = []
# Free variables found by the symbol table scan, including
# variables used only in nested scopes, are included here.
self.freevars = []
self.cellvars = []
# The closure list is used to track the order of cell
# variables and free variables in the resulting code object.
# The offsets used by LOAD_CLOSURE/LOAD_DEREF refer to both
# kinds of variables.
self.closure = []
self.varnames = list(args) or []
for i in range(len(self.varnames)):
var = self.varnames[i]
if isinstance(var, TupleArg):
self.varnames[i] = var.getName()
self.stage = RAW
def setDocstring(self, doc):
self.docstring = doc
def setFlag(self, flag):
self.flags = self.flags | flag
if flag == CO_VARARGS:
self.argcount = self.argcount - 1
def checkFlag(self, flag):
if self.flags & flag:
return 1
def setFreeVars(self, names):
self.freevars = list(names)
def setCellVars(self, names):
self.cellvars = names
def getCode(self):
"""Get a Python code object"""
assert self.stage == RAW
self.computeStackDepth()
self.flattenGraph()
assert self.stage == FLAT
self.convertArgs()
assert self.stage == CONV
self.makeByteCode()
assert self.stage == DONE
return self.newCodeObject()
def dump(self, io=None):
if io:
save = sys.stdout
sys.stdout = io
pc = 0
for t in self.insts:
opname = t[0]
if opname == "SET_LINENO":
print
if len(t) == 1:
print "\t", "%3d" % pc, opname
pc = pc + 1
else:
print "\t", "%3d" % pc, opname, t[1]
pc = pc + 3
if io:
sys.stdout = save
def computeStackDepth(self):
"""Compute the max stack depth.
Approach is to compute the stack effect of each basic block.
Then find the path through the code with the largest total
effect.
"""
depth = {}
exit = None
for b in self.getBlocks():
depth[b] = findDepth(b.getInstructions())
seen = {}
def max_depth(b, d):
if b in seen:
return d
seen[b] = 1
d = d + depth[b]
children = b.get_children()
if children:
return max([max_depth(c, d) for c in children])
else:
if not b.label == "exit":
return max_depth(self.exit, d)
else:
return d
self.stacksize = max_depth(self.entry, 0)
def flattenGraph(self):
"""Arrange the blocks in order and resolve jumps"""
assert self.stage == RAW
self.insts = insts = []
pc = 0
begin = {}
end = {}
for b in self.getBlocksInOrder():
begin[b] = pc
for inst in b.getInstructions():
insts.append(inst)
if len(inst) == 1:
pc = pc + 1
elif inst[0] != "SET_LINENO":
# arg takes 2 bytes
pc = pc + 3
end[b] = pc
pc = 0
for i in range(len(insts)):
inst = insts[i]
if len(inst) == 1:
pc = pc + 1
elif inst[0] != "SET_LINENO":
pc = pc + 3
opname = inst[0]
if opname in self.hasjrel:
oparg = inst[1]
offset = begin[oparg] - pc
insts[i] = opname, offset
elif opname in self.hasjabs:
insts[i] = opname, begin[inst[1]]
self.stage = FLAT
hasjrel = set()
for i in dis.hasjrel:
hasjrel.add(dis.opname[i])
hasjabs = set()
for i in dis.hasjabs:
hasjabs.add(dis.opname[i])
def convertArgs(self):
"""Convert arguments from symbolic to concrete form"""
assert self.stage == FLAT
self.consts.insert(0, self.docstring)
self.sort_cellvars()
for i in range(len(self.insts)):
t = self.insts[i]
if len(t) == 2:
opname, oparg = t
conv = self._converters.get(opname, None)
if conv:
self.insts[i] = opname, conv(self, oparg)
self.stage = CONV
def sort_cellvars(self):
"""Sort cellvars in the order of varnames and prune from freevars.
"""
cells = {}
for name in self.cellvars:
cells[name] = 1
self.cellvars = [name for name in self.varnames
if name in cells]
for name in self.cellvars:
del cells[name]
self.cellvars = self.cellvars + cells.keys()
self.closure = self.cellvars + self.freevars
def _lookupName(self, name, list):
"""Return index of name in list, appending if necessary
This routine uses a list instead of a dictionary, because a
dictionary can't store two different keys if the keys have the
same value but different types, e.g. 2 and 2L. The compiler
must treat these two separately, so it does an explicit type
comparison before comparing the values.
"""
t = type(name)
for i in range(len(list)):
if t == type(list[i]) and list[i] == name:
return i
end = len(list)
list.append(name)
return end
_converters = {}
def _convert_LOAD_CONST(self, arg):
if hasattr(arg, 'getCode'):
arg = arg.getCode()
return self._lookupName(arg, self.consts)
def _convert_LOAD_FAST(self, arg):
self._lookupName(arg, self.names)
return self._lookupName(arg, self.varnames)
_convert_STORE_FAST = _convert_LOAD_FAST
_convert_DELETE_FAST = _convert_LOAD_FAST
def _convert_LOAD_NAME(self, arg):
if self.klass is None:
self._lookupName(arg, self.varnames)
return self._lookupName(arg, self.names)
def _convert_NAME(self, arg):
if self.klass is None:
self._lookupName(arg, self.varnames)
return self._lookupName(arg, self.names)
_convert_STORE_NAME = _convert_NAME
_convert_DELETE_NAME = _convert_NAME
_convert_IMPORT_NAME = _convert_NAME
_convert_IMPORT_FROM = _convert_NAME
_convert_STORE_ATTR = _convert_NAME
_convert_LOAD_ATTR = _convert_NAME
_convert_DELETE_ATTR = _convert_NAME
_convert_LOAD_GLOBAL = _convert_NAME
_convert_STORE_GLOBAL = _convert_NAME
_convert_DELETE_GLOBAL = _convert_NAME
def _convert_DEREF(self, arg):
self._lookupName(arg, self.names)
self._lookupName(arg, self.varnames)
return self._lookupName(arg, self.closure)
_convert_LOAD_DEREF = _convert_DEREF
_convert_STORE_DEREF = _convert_DEREF
def _convert_LOAD_CLOSURE(self, arg):
self._lookupName(arg, self.varnames)
return self._lookupName(arg, self.closure)
_cmp = list(dis.cmp_op)
def _convert_COMPARE_OP(self, arg):
return self._cmp.index(arg)
# similarly for other opcodes...
for name, obj in locals().items():
if name[:9] == "_convert_":
opname = name[9:]
_converters[opname] = obj
del name, obj, opname
def makeByteCode(self):
assert self.stage == CONV
self.lnotab = lnotab = LineAddrTable()
for t in self.insts:
opname = t[0]
if len(t) == 1:
lnotab.addCode(self.opnum[opname])
else:
oparg = t[1]
if opname == "SET_LINENO":
lnotab.nextLine(oparg)
continue
hi, lo = twobyte(oparg)
try:
lnotab.addCode(self.opnum[opname], lo, hi)
except ValueError:
print opname, oparg
print self.opnum[opname], lo, hi
raise
self.stage = DONE
opnum = {}
for num in range(len(dis.opname)):
opnum[dis.opname[num]] = num
del num
def newCodeObject(self):
assert self.stage == DONE
if (self.flags & CO_NEWLOCALS) == 0:
nlocals = 0
else:
nlocals = len(self.varnames)
argcount = self.argcount
if self.flags & CO_VARKEYWORDS:
argcount = argcount - 1
return types.CodeType(argcount, nlocals, self.stacksize, self.flags,
self.lnotab.getCode(), self.getConsts(),
tuple(self.names), tuple(self.varnames),
self.filename, self.name, self.lnotab.firstline,
self.lnotab.getTable(), tuple(self.freevars),
tuple(self.cellvars))
def getConsts(self):
"""Return a tuple for the const slot of the code object
Must convert references to code (MAKE_FUNCTION) to code
objects recursively.
"""
l = []
for elt in self.consts:
if isinstance(elt, PyFlowGraph):
elt = elt.getCode()
l.append(elt)
return tuple(l)
def isJump(opname):
if opname[:4] == 'JUMP':
return 1
class TupleArg:
"""Helper for marking func defs with nested tuples in arglist"""
def __init__(self, count, names):
self.count = count
self.names = names
def __repr__(self):
return "TupleArg(%s, %s)" % (self.count, self.names)
def getName(self):
return ".%d" % self.count
def getArgCount(args):
argcount = len(args)
if args:
for arg in args:
if isinstance(arg, TupleArg):
numNames = len(misc.flatten(arg.names))
argcount = argcount - numNames
return argcount
def twobyte(val):
"""Convert an int argument into high and low bytes"""
assert isinstance(val, int)
return divmod(val, 256)
class LineAddrTable:
"""lnotab
This class builds the lnotab, which is documented in compile.c.
Here's a brief recap:
For each SET_LINENO instruction after the first one, two bytes are
added to lnotab. (In some cases, multiple two-byte entries are
added.) The first byte is the distance in bytes between the
instruction for the last SET_LINENO and the current SET_LINENO.
The second byte is offset in line numbers. If either offset is
greater than 255, multiple two-byte entries are added -- see
compile.c for the delicate details.
"""
def __init__(self):
self.code = []
self.codeOffset = 0
self.firstline = 0
self.lastline = 0
self.lastoff = 0
self.lnotab = []
def addCode(self, *args):
for arg in args:
self.code.append(chr(arg))
self.codeOffset = self.codeOffset + len(args)
def nextLine(self, lineno):
if self.firstline == 0:
self.firstline = lineno
self.lastline = lineno
else:
# compute deltas
addr = self.codeOffset - self.lastoff
line = lineno - self.lastline
# Python assumes that lineno always increases with
# increasing bytecode address (lnotab is unsigned char).
# Depending on when SET_LINENO instructions are emitted
# this is not always true. Consider the code:
# a = (1,
# b)
# In the bytecode stream, the assignment to "a" occurs
# after the loading of "b". This works with the C Python
# compiler because it only generates a SET_LINENO instruction
# for the assignment.
if line >= 0:
push = self.lnotab.append
while addr > 255:
push(255); push(0)
addr -= 255
while line > 255:
push(addr); push(255)
line -= 255
addr = 0
if addr > 0 or line > 0:
push(addr); push(line)
self.lastline = lineno
self.lastoff = self.codeOffset
def getCode(self):
return ''.join(self.code)
def getTable(self):
return ''.join(map(chr, self.lnotab))
class StackDepthTracker:
# XXX 1. need to keep track of stack depth on jumps
# XXX 2. at least partly as a result, this code is broken
def findDepth(self, insts, debug=0):
depth = 0
maxDepth = 0
for i in insts:
opname = i[0]
if debug:
print i,
delta = self.effect.get(opname, None)
if delta is not None:
depth = depth + delta
else:
# now check patterns
for pat, pat_delta in self.patterns:
if opname[:len(pat)] == pat:
delta = pat_delta
depth = depth + delta
break
# if we still haven't found a match
if delta is None:
meth = getattr(self, opname, None)
if meth is not None:
depth = depth + meth(i[1])
if depth > maxDepth:
maxDepth = depth
if debug:
print depth, maxDepth
return maxDepth
effect = {
'POP_TOP': -1,
'DUP_TOP': 1,
'LIST_APPEND': -1,
'SET_ADD': -1,
'MAP_ADD': -2,
'SLICE+1': -1,
'SLICE+2': -1,
'SLICE+3': -2,
'STORE_SLICE+0': -1,
'STORE_SLICE+1': -2,
'STORE_SLICE+2': -2,
'STORE_SLICE+3': -3,
'DELETE_SLICE+0': -1,
'DELETE_SLICE+1': -2,
'DELETE_SLICE+2': -2,
'DELETE_SLICE+3': -3,
'STORE_SUBSCR': -3,
'DELETE_SUBSCR': -2,
# PRINT_EXPR?
'PRINT_ITEM': -1,
'RETURN_VALUE': -1,
'YIELD_VALUE': -1,
'EXEC_STMT': -3,
'BUILD_CLASS': -2,
'STORE_NAME': -1,
'STORE_ATTR': -2,
'DELETE_ATTR': -1,
'STORE_GLOBAL': -1,
'BUILD_MAP': 1,
'COMPARE_OP': -1,
'STORE_FAST': -1,
'IMPORT_STAR': -1,
'IMPORT_NAME': -1,
'IMPORT_FROM': 1,
'LOAD_ATTR': 0, # unlike other loads
# close enough...
'SETUP_EXCEPT': 3,
'SETUP_FINALLY': 3,
'FOR_ITER': 1,
'WITH_CLEANUP': -1,
}
# use pattern match
patterns = [
('BINARY_', -1),
('LOAD_', 1),
]
def UNPACK_SEQUENCE(self, count):
return count-1
def BUILD_TUPLE(self, count):
return -count+1
def BUILD_LIST(self, count):
return -count+1
def BUILD_SET(self, count):
return -count+1
def CALL_FUNCTION(self, argc):
hi, lo = divmod(argc, 256)
return -(lo + hi * 2)
def CALL_FUNCTION_VAR(self, argc):
return self.CALL_FUNCTION(argc)-1
def CALL_FUNCTION_KW(self, argc):
return self.CALL_FUNCTION(argc)-1
def CALL_FUNCTION_VAR_KW(self, argc):
return self.CALL_FUNCTION(argc)-2
def MAKE_FUNCTION(self, argc):
return -argc
def MAKE_CLOSURE(self, argc):
# XXX need to account for free variables too!
return -argc
def BUILD_SLICE(self, argc):
if argc == 2:
return -1
elif argc == 3:
return -2
def DUP_TOPX(self, argc):
return argc
findDepth = StackDepthTracker().findDepth

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,462 @@
"""Module symbol-table generator"""
from compiler import ast
from compiler.consts import SC_LOCAL, SC_GLOBAL_IMPLICIT, SC_GLOBAL_EXPLICT, \
SC_FREE, SC_CELL, SC_UNKNOWN
from compiler.misc import mangle
import types
import sys
MANGLE_LEN = 256
class Scope:
# XXX how much information do I need about each name?
def __init__(self, name, module, klass=None):
self.name = name
self.module = module
self.defs = {}
self.uses = {}
self.globals = {}
self.params = {}
self.frees = {}
self.cells = {}
self.children = []
# nested is true if the class could contain free variables,
# i.e. if it is nested within another function.
self.nested = None
self.generator = None
self.klass = None
if klass is not None:
for i in range(len(klass)):
if klass[i] != '_':
self.klass = klass[i:]
break
def __repr__(self):
return "<%s: %s>" % (self.__class__.__name__, self.name)
def mangle(self, name):
if self.klass is None:
return name
return mangle(name, self.klass)
def add_def(self, name):
self.defs[self.mangle(name)] = 1
def add_use(self, name):
self.uses[self.mangle(name)] = 1
def add_global(self, name):
name = self.mangle(name)
if name in self.uses or name in self.defs:
pass # XXX warn about global following def/use
if name in self.params:
raise SyntaxError, "%s in %s is global and parameter" % \
(name, self.name)
self.globals[name] = 1
self.module.add_def(name)
def add_param(self, name):
name = self.mangle(name)
self.defs[name] = 1
self.params[name] = 1
def get_names(self):
d = {}
d.update(self.defs)
d.update(self.uses)
d.update(self.globals)
return d.keys()
def add_child(self, child):
self.children.append(child)
def get_children(self):
return self.children
def DEBUG(self):
print >> sys.stderr, self.name, self.nested and "nested" or ""
print >> sys.stderr, "\tglobals: ", self.globals
print >> sys.stderr, "\tcells: ", self.cells
print >> sys.stderr, "\tdefs: ", self.defs
print >> sys.stderr, "\tuses: ", self.uses
print >> sys.stderr, "\tfrees:", self.frees
def check_name(self, name):
"""Return scope of name.
The scope of a name could be LOCAL, GLOBAL, FREE, or CELL.
"""
if name in self.globals:
return SC_GLOBAL_EXPLICT
if name in self.cells:
return SC_CELL
if name in self.defs:
return SC_LOCAL
if self.nested and (name in self.frees or name in self.uses):
return SC_FREE
if self.nested:
return SC_UNKNOWN
else:
return SC_GLOBAL_IMPLICIT
def get_free_vars(self):
if not self.nested:
return ()
free = {}
free.update(self.frees)
for name in self.uses.keys():
if name not in self.defs and name not in self.globals:
free[name] = 1
return free.keys()
def handle_children(self):
for child in self.children:
frees = child.get_free_vars()
globals = self.add_frees(frees)
for name in globals:
child.force_global(name)
def force_global(self, name):
"""Force name to be global in scope.
Some child of the current node had a free reference to name.
When the child was processed, it was labelled a free
variable. Now that all its enclosing scope have been
processed, the name is known to be a global or builtin. So
walk back down the child chain and set the name to be global
rather than free.
Be careful to stop if a child does not think the name is
free.
"""
self.globals[name] = 1
if name in self.frees:
del self.frees[name]
for child in self.children:
if child.check_name(name) == SC_FREE:
child.force_global(name)
def add_frees(self, names):
"""Process list of free vars from nested scope.
Returns a list of names that are either 1) declared global in the
parent or 2) undefined in a top-level parent. In either case,
the nested scope should treat them as globals.
"""
child_globals = []
for name in names:
sc = self.check_name(name)
if self.nested:
if sc == SC_UNKNOWN or sc == SC_FREE \
or isinstance(self, ClassScope):
self.frees[name] = 1
elif sc == SC_GLOBAL_IMPLICIT:
child_globals.append(name)
elif isinstance(self, FunctionScope) and sc == SC_LOCAL:
self.cells[name] = 1
elif sc != SC_CELL:
child_globals.append(name)
else:
if sc == SC_LOCAL:
self.cells[name] = 1
elif sc != SC_CELL:
child_globals.append(name)
return child_globals
def get_cell_vars(self):
return self.cells.keys()
class ModuleScope(Scope):
__super_init = Scope.__init__
def __init__(self):
self.__super_init("global", self)
class FunctionScope(Scope):
pass
class GenExprScope(Scope):
__super_init = Scope.__init__
__counter = 1
def __init__(self, module, klass=None):
i = self.__counter
self.__counter += 1
self.__super_init("generator expression<%d>"%i, module, klass)
self.add_param('.0')
def get_names(self):
keys = Scope.get_names(self)
return keys
class LambdaScope(FunctionScope):
__super_init = Scope.__init__
__counter = 1
def __init__(self, module, klass=None):
i = self.__counter
self.__counter += 1
self.__super_init("lambda.%d" % i, module, klass)
class ClassScope(Scope):
__super_init = Scope.__init__
def __init__(self, name, module):
self.__super_init(name, module, name)
class SymbolVisitor:
def __init__(self):
self.scopes = {}
self.klass = None
# node that define new scopes
def visitModule(self, node):
scope = self.module = self.scopes[node] = ModuleScope()
self.visit(node.node, scope)
visitExpression = visitModule
def visitFunction(self, node, parent):
if node.decorators:
self.visit(node.decorators, parent)
parent.add_def(node.name)
for n in node.defaults:
self.visit(n, parent)
scope = FunctionScope(node.name, self.module, self.klass)
if parent.nested or isinstance(parent, FunctionScope):
scope.nested = 1
self.scopes[node] = scope
self._do_args(scope, node.argnames)
self.visit(node.code, scope)
self.handle_free_vars(scope, parent)
def visitGenExpr(self, node, parent):
scope = GenExprScope(self.module, self.klass);
if parent.nested or isinstance(parent, FunctionScope) \
or isinstance(parent, GenExprScope):
scope.nested = 1
self.scopes[node] = scope
self.visit(node.code, scope)
self.handle_free_vars(scope, parent)
def visitGenExprInner(self, node, scope):
for genfor in node.quals:
self.visit(genfor, scope)
self.visit(node.expr, scope)
def visitGenExprFor(self, node, scope):
self.visit(node.assign, scope, 1)
self.visit(node.iter, scope)
for if_ in node.ifs:
self.visit(if_, scope)
def visitGenExprIf(self, node, scope):
self.visit(node.test, scope)
def visitLambda(self, node, parent, assign=0):
# Lambda is an expression, so it could appear in an expression
# context where assign is passed. The transformer should catch
# any code that has a lambda on the left-hand side.
assert not assign
for n in node.defaults:
self.visit(n, parent)
scope = LambdaScope(self.module, self.klass)
if parent.nested or isinstance(parent, FunctionScope):
scope.nested = 1
self.scopes[node] = scope
self._do_args(scope, node.argnames)
self.visit(node.code, scope)
self.handle_free_vars(scope, parent)
def _do_args(self, scope, args):
for name in args:
if type(name) == types.TupleType:
self._do_args(scope, name)
else:
scope.add_param(name)
def handle_free_vars(self, scope, parent):
parent.add_child(scope)
scope.handle_children()
def visitClass(self, node, parent):
parent.add_def(node.name)
for n in node.bases:
self.visit(n, parent)
scope = ClassScope(node.name, self.module)
if parent.nested or isinstance(parent, FunctionScope):
scope.nested = 1
if node.doc is not None:
scope.add_def('__doc__')
scope.add_def('__module__')
self.scopes[node] = scope
prev = self.klass
self.klass = node.name
self.visit(node.code, scope)
self.klass = prev
self.handle_free_vars(scope, parent)
# name can be a def or a use
# XXX a few calls and nodes expect a third "assign" arg that is
# true if the name is being used as an assignment. only
# expressions contained within statements may have the assign arg.
def visitName(self, node, scope, assign=0):
if assign:
scope.add_def(node.name)
else:
scope.add_use(node.name)
# operations that bind new names
def visitFor(self, node, scope):
self.visit(node.assign, scope, 1)
self.visit(node.list, scope)
self.visit(node.body, scope)
if node.else_:
self.visit(node.else_, scope)
def visitFrom(self, node, scope):
for name, asname in node.names:
if name == "*":
continue
scope.add_def(asname or name)
def visitImport(self, node, scope):
for name, asname in node.names:
i = name.find(".")
if i > -1:
name = name[:i]
scope.add_def(asname or name)
def visitGlobal(self, node, scope):
for name in node.names:
scope.add_global(name)
def visitAssign(self, node, scope):
"""Propagate assignment flag down to child nodes.
The Assign node doesn't itself contains the variables being
assigned to. Instead, the children in node.nodes are visited
with the assign flag set to true. When the names occur in
those nodes, they are marked as defs.
Some names that occur in an assignment target are not bound by
the assignment, e.g. a name occurring inside a slice. The
visitor handles these nodes specially; they do not propagate
the assign flag to their children.
"""
for n in node.nodes:
self.visit(n, scope, 1)
self.visit(node.expr, scope)
def visitAssName(self, node, scope, assign=1):
scope.add_def(node.name)
def visitAssAttr(self, node, scope, assign=0):
self.visit(node.expr, scope, 0)
def visitSubscript(self, node, scope, assign=0):
self.visit(node.expr, scope, 0)
for n in node.subs:
self.visit(n, scope, 0)
def visitSlice(self, node, scope, assign=0):
self.visit(node.expr, scope, 0)
if node.lower:
self.visit(node.lower, scope, 0)
if node.upper:
self.visit(node.upper, scope, 0)
def visitAugAssign(self, node, scope):
# If the LHS is a name, then this counts as assignment.
# Otherwise, it's just use.
self.visit(node.node, scope)
if isinstance(node.node, ast.Name):
self.visit(node.node, scope, 1) # XXX worry about this
self.visit(node.expr, scope)
# prune if statements if tests are false
_const_types = types.StringType, types.IntType, types.FloatType
def visitIf(self, node, scope):
for test, body in node.tests:
if isinstance(test, ast.Const):
if type(test.value) in self._const_types:
if not test.value:
continue
self.visit(test, scope)
self.visit(body, scope)
if node.else_:
self.visit(node.else_, scope)
# a yield statement signals a generator
def visitYield(self, node, scope):
scope.generator = 1
self.visit(node.value, scope)
def list_eq(l1, l2):
return sorted(l1) == sorted(l2)
if __name__ == "__main__":
import sys
from compiler import parseFile, walk
import symtable
def get_names(syms):
return [s for s in [s.get_name() for s in syms.get_symbols()]
if not (s.startswith('_[') or s.startswith('.'))]
for file in sys.argv[1:]:
print file
f = open(file)
buf = f.read()
f.close()
syms = symtable.symtable(buf, file, "exec")
mod_names = get_names(syms)
tree = parseFile(file)
s = SymbolVisitor()
walk(tree, s)
# compare module-level symbols
names2 = s.scopes[tree].get_names()
if not list_eq(mod_names, names2):
print
print "oops", file
print sorted(mod_names)
print sorted(names2)
sys.exit(-1)
d = {}
d.update(s.scopes)
del d[tree]
scopes = d.values()
del d
for s in syms.get_symbols():
if s.is_namespace():
l = [sc for sc in scopes
if sc.name == s.get_name()]
if len(l) > 1:
print "skipping", s.get_name()
else:
if not list_eq(get_names(s.get_namespace()),
l[0].get_names()):
print s.get_name()
print sorted(get_names(s.get_namespace()))
print sorted(l[0].get_names())
sys.exit(-1)

View File

@@ -0,0 +1,46 @@
"""Check for errs in the AST.
The Python parser does not catch all syntax errors. Others, like
assignments with invalid targets, are caught in the code generation
phase.
The compiler package catches some errors in the transformer module.
But it seems clearer to write checkers that use the AST to detect
errors.
"""
from compiler import ast, walk
def check(tree, multi=None):
v = SyntaxErrorChecker(multi)
walk(tree, v)
return v.errors
class SyntaxErrorChecker:
"""A visitor to find syntax errors in the AST."""
def __init__(self, multi=None):
"""Create new visitor object.
If optional argument multi is not None, then print messages
for each error rather than raising a SyntaxError for the
first.
"""
self.multi = multi
self.errors = 0
def error(self, node, msg):
self.errors = self.errors + 1
if self.multi is not None:
print "%s:%s: %s" % (node.filename, node.lineno, msg)
else:
raise SyntaxError, "%s (%s:%s)" % (msg, node.filename, node.lineno)
def visitAssign(self, node):
# the transformer module handles many of these
pass
## for target in node.nodes:
## if isinstance(target, ast.AssList):
## if target.lineno is None:
## target.lineno = node.lineno
## self.error(target, "can't assign to list comprehension")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,113 @@
from compiler import ast
# XXX should probably rename ASTVisitor to ASTWalker
# XXX can it be made even more generic?
class ASTVisitor:
"""Performs a depth-first walk of the AST
The ASTVisitor will walk the AST, performing either a preorder or
postorder traversal depending on which method is called.
methods:
preorder(tree, visitor)
postorder(tree, visitor)
tree: an instance of ast.Node
visitor: an instance with visitXXX methods
The ASTVisitor is responsible for walking over the tree in the
correct order. For each node, it checks the visitor argument for
a method named 'visitNodeType' where NodeType is the name of the
node's class, e.g. Class. If the method exists, it is called
with the node as its sole argument.
The visitor method for a particular node type can control how
child nodes are visited during a preorder walk. (It can't control
the order during a postorder walk, because it is called _after_
the walk has occurred.) The ASTVisitor modifies the visitor
argument by adding a visit method to the visitor; this method can
be used to visit a child node of arbitrary type.
"""
VERBOSE = 0
def __init__(self):
self.node = None
self._cache = {}
def default(self, node, *args):
for child in node.getChildNodes():
self.dispatch(child, *args)
def dispatch(self, node, *args):
self.node = node
klass = node.__class__
meth = self._cache.get(klass, None)
if meth is None:
className = klass.__name__
meth = getattr(self.visitor, 'visit' + className, self.default)
self._cache[klass] = meth
## if self.VERBOSE > 0:
## className = klass.__name__
## if self.VERBOSE == 1:
## if meth == 0:
## print "dispatch", className
## else:
## print "dispatch", className, (meth and meth.__name__ or '')
return meth(node, *args)
def preorder(self, tree, visitor, *args):
"""Do preorder walk of tree using visitor"""
self.visitor = visitor
visitor.visit = self.dispatch
self.dispatch(tree, *args) # XXX *args make sense?
class ExampleASTVisitor(ASTVisitor):
"""Prints examples of the nodes that aren't visited
This visitor-driver is only useful for development, when it's
helpful to develop a visitor incrementally, and get feedback on what
you still have to do.
"""
examples = {}
def dispatch(self, node, *args):
self.node = node
meth = self._cache.get(node.__class__, None)
className = node.__class__.__name__
if meth is None:
meth = getattr(self.visitor, 'visit' + className, 0)
self._cache[node.__class__] = meth
if self.VERBOSE > 1:
print "dispatch", className, (meth and meth.__name__ or '')
if meth:
meth(node, *args)
elif self.VERBOSE > 0:
klass = node.__class__
if klass not in self.examples:
self.examples[klass] = klass
print
print self.visitor
print klass
for attr in dir(node):
if attr[0] != '_':
print "\t", "%-12.12s" % attr, getattr(node, attr)
print
return self.default(node, *args)
# XXX this is an API change
_walker = ASTVisitor
def walk(tree, visitor, walker=None, verbose=None):
if walker is None:
walker = _walker()
if verbose is not None:
walker.VERBOSE = verbose
walker.preorder(tree, visitor)
return walker.visitor
def dumpNode(node):
print node.__class__
for attr in dir(node):
if attr[0] != '_':
print "\t", "%-10.10s" % attr, getattr(node, attr)

View File

@@ -0,0 +1,154 @@
"""Utilities for with-statement contexts. See PEP 343."""
import sys
from functools import wraps
from warnings import warn
__all__ = ["contextmanager", "nested", "closing"]
class GeneratorContextManager(object):
"""Helper for @contextmanager decorator."""
def __init__(self, gen):
self.gen = gen
def __enter__(self):
try:
return self.gen.next()
except StopIteration:
raise RuntimeError("generator didn't yield")
def __exit__(self, type, value, traceback):
if type is None:
try:
self.gen.next()
except StopIteration:
return
else:
raise RuntimeError("generator didn't stop")
else:
if value is None:
# Need to force instantiation so we can reliably
# tell if we get the same exception back
value = type()
try:
self.gen.throw(type, value, traceback)
raise RuntimeError("generator didn't stop after throw()")
except StopIteration, exc:
# Suppress the exception *unless* it's the same exception that
# was passed to throw(). This prevents a StopIteration
# raised inside the "with" statement from being suppressed
return exc is not value
except:
# only re-raise if it's *not* the exception that was
# passed to throw(), because __exit__() must not raise
# an exception unless __exit__() itself failed. But throw()
# has to raise the exception to signal propagation, so this
# fixes the impedance mismatch between the throw() protocol
# and the __exit__() protocol.
#
if sys.exc_info()[1] is not value:
raise
def contextmanager(func):
"""@contextmanager decorator.
Typical usage:
@contextmanager
def some_generator(<arguments>):
<setup>
try:
yield <value>
finally:
<cleanup>
This makes this:
with some_generator(<arguments>) as <variable>:
<body>
equivalent to this:
<setup>
try:
<variable> = <value>
<body>
finally:
<cleanup>
"""
@wraps(func)
def helper(*args, **kwds):
return GeneratorContextManager(func(*args, **kwds))
return helper
@contextmanager
def nested(*managers):
"""Combine multiple context managers into a single nested context manager.
This function has been deprecated in favour of the multiple manager form
of the with statement.
The one advantage of this function over the multiple manager form of the
with statement is that argument unpacking allows it to be
used with a variable number of context managers as follows:
with nested(*managers):
do_something()
"""
warn("With-statements now directly support multiple context managers",
DeprecationWarning, 3)
exits = []
vars = []
exc = (None, None, None)
try:
for mgr in managers:
exit = mgr.__exit__
enter = mgr.__enter__
vars.append(enter())
exits.append(exit)
yield vars
except:
exc = sys.exc_info()
finally:
while exits:
exit = exits.pop()
try:
if exit(*exc):
exc = (None, None, None)
except:
exc = sys.exc_info()
if exc != (None, None, None):
# Don't rely on sys.exc_info() still containing
# the right information. Another exception may
# have been raised and caught by an exit method
raise exc[0], exc[1], exc[2]
class closing(object):
"""Context to automatically close something at the end of a block.
Code like this:
with closing(<module>.open(<arguments>)) as f:
<block>
is equivalent to this:
f = <module>.open(<arguments>)
try:
<block>
finally:
f.close()
"""
def __init__(self, thing):
self.thing = thing
def __enter__(self):
return self.thing
def __exit__(self, *exc_info):
self.thing.close()

File diff suppressed because it is too large Load Diff

433
src/python/stdlib/copy.py Normal file
View File

@@ -0,0 +1,433 @@
"""Generic (shallow and deep) copying operations.
Interface summary:
import copy
x = copy.copy(y) # make a shallow copy of y
x = copy.deepcopy(y) # make a deep copy of y
For module specific errors, copy.Error is raised.
The difference between shallow and deep copying is only relevant for
compound objects (objects that contain other objects, like lists or
class instances).
- A shallow copy constructs a new compound object and then (to the
extent possible) inserts *the same objects* into it that the
original contains.
- A deep copy constructs a new compound object and then, recursively,
inserts *copies* into it of the objects found in the original.
Two problems often exist with deep copy operations that don't exist
with shallow copy operations:
a) recursive objects (compound objects that, directly or indirectly,
contain a reference to themselves) may cause a recursive loop
b) because deep copy copies *everything* it may copy too much, e.g.
administrative data structures that should be shared even between
copies
Python's deep copy operation avoids these problems by:
a) keeping a table of objects already copied during the current
copying pass
b) letting user-defined classes override the copying operation or the
set of components copied
This version does not copy types like module, class, function, method,
nor stack trace, stack frame, nor file, socket, window, nor array, nor
any similar types.
Classes can use the same interfaces to control copying that they use
to control pickling: they can define methods called __getinitargs__(),
__getstate__() and __setstate__(). See the documentation for module
"pickle" for information on these methods.
"""
import types
import weakref
from copy_reg import dispatch_table
class Error(Exception):
pass
error = Error # backward compatibility
try:
from org.python.core import PyStringMap
except ImportError:
PyStringMap = None
__all__ = ["Error", "copy", "deepcopy"]
def copy(x):
"""Shallow copy operation on arbitrary Python objects.
See the module's __doc__ string for more info.
"""
cls = type(x)
copier = _copy_dispatch.get(cls)
if copier:
return copier(x)
copier = getattr(cls, "__copy__", None)
if copier:
return copier(x)
reductor = dispatch_table.get(cls)
if reductor:
rv = reductor(x)
else:
reductor = getattr(x, "__reduce_ex__", None)
if reductor:
rv = reductor(2)
else:
reductor = getattr(x, "__reduce__", None)
if reductor:
rv = reductor()
else:
raise Error("un(shallow)copyable object of type %s" % cls)
return _reconstruct(x, rv, 0)
_copy_dispatch = d = {}
def _copy_immutable(x):
return x
for t in (type(None), int, long, float, bool, str, tuple,
frozenset, type, xrange, types.ClassType,
types.BuiltinFunctionType, type(Ellipsis),
types.FunctionType, weakref.ref):
d[t] = _copy_immutable
for name in ("ComplexType", "UnicodeType", "CodeType"):
t = getattr(types, name, None)
if t is not None:
d[t] = _copy_immutable
def _copy_with_constructor(x):
return type(x)(x)
for t in (list, dict, set):
d[t] = _copy_with_constructor
def _copy_with_copy_method(x):
return x.copy()
if PyStringMap is not None:
d[PyStringMap] = _copy_with_copy_method
def _copy_inst(x):
if hasattr(x, '__copy__'):
return x.__copy__()
if hasattr(x, '__getinitargs__'):
args = x.__getinitargs__()
y = x.__class__(*args)
else:
y = _EmptyClass()
y.__class__ = x.__class__
if hasattr(x, '__getstate__'):
state = x.__getstate__()
else:
state = x.__dict__
if hasattr(y, '__setstate__'):
y.__setstate__(state)
else:
y.__dict__.update(state)
return y
d[types.InstanceType] = _copy_inst
del d
def deepcopy(x, memo=None, _nil=[]):
"""Deep copy operation on arbitrary Python objects.
See the module's __doc__ string for more info.
"""
if memo is None:
memo = {}
d = id(x)
y = memo.get(d, _nil)
if y is not _nil:
return y
cls = type(x)
copier = _deepcopy_dispatch.get(cls)
if copier:
y = copier(x, memo)
else:
try:
issc = issubclass(cls, type)
except TypeError: # cls is not a class (old Boost; see SF #502085)
issc = 0
if issc:
y = _deepcopy_atomic(x, memo)
else:
copier = getattr(x, "__deepcopy__", None)
if copier:
y = copier(memo)
else:
reductor = dispatch_table.get(cls)
if reductor:
rv = reductor(x)
else:
reductor = getattr(x, "__reduce_ex__", None)
if reductor:
rv = reductor(2)
else:
reductor = getattr(x, "__reduce__", None)
if reductor:
rv = reductor()
else:
raise Error(
"un(deep)copyable object of type %s" % cls)
y = _reconstruct(x, rv, 1, memo)
memo[d] = y
_keep_alive(x, memo) # Make sure x lives at least as long as d
return y
_deepcopy_dispatch = d = {}
def _deepcopy_atomic(x, memo):
return x
d[type(None)] = _deepcopy_atomic
d[type(Ellipsis)] = _deepcopy_atomic
d[int] = _deepcopy_atomic
d[long] = _deepcopy_atomic
d[float] = _deepcopy_atomic
d[bool] = _deepcopy_atomic
try:
d[complex] = _deepcopy_atomic
except NameError:
pass
d[str] = _deepcopy_atomic
try:
d[unicode] = _deepcopy_atomic
except NameError:
pass
try:
d[types.CodeType] = _deepcopy_atomic
except AttributeError:
pass
d[type] = _deepcopy_atomic
d[xrange] = _deepcopy_atomic
d[types.ClassType] = _deepcopy_atomic
d[types.BuiltinFunctionType] = _deepcopy_atomic
d[types.FunctionType] = _deepcopy_atomic
d[weakref.ref] = _deepcopy_atomic
def _deepcopy_list(x, memo):
y = []
memo[id(x)] = y
for a in x:
y.append(deepcopy(a, memo))
return y
d[list] = _deepcopy_list
def _deepcopy_tuple(x, memo):
y = []
for a in x:
y.append(deepcopy(a, memo))
d = id(x)
try:
return memo[d]
except KeyError:
pass
for i in range(len(x)):
if x[i] is not y[i]:
y = tuple(y)
break
else:
y = x
memo[d] = y
return y
d[tuple] = _deepcopy_tuple
def _deepcopy_dict(x, memo):
y = {}
memo[id(x)] = y
for key, value in x.iteritems():
y[deepcopy(key, memo)] = deepcopy(value, memo)
return y
d[dict] = _deepcopy_dict
if PyStringMap is not None:
d[PyStringMap] = _deepcopy_dict
def _deepcopy_method(x, memo): # Copy instance methods
return type(x)(x.im_func, deepcopy(x.im_self, memo), x.im_class)
_deepcopy_dispatch[types.MethodType] = _deepcopy_method
def _keep_alive(x, memo):
"""Keeps a reference to the object x in the memo.
Because we remember objects by their id, we have
to assure that possibly temporary objects are kept
alive by referencing them.
We store a reference at the id of the memo, which should
normally not be used unless someone tries to deepcopy
the memo itself...
"""
try:
memo[id(memo)].append(x)
except KeyError:
# aha, this is the first one :-)
memo[id(memo)]=[x]
def _deepcopy_inst(x, memo):
if hasattr(x, '__deepcopy__'):
return x.__deepcopy__(memo)
if hasattr(x, '__getinitargs__'):
args = x.__getinitargs__()
args = deepcopy(args, memo)
y = x.__class__(*args)
else:
y = _EmptyClass()
y.__class__ = x.__class__
memo[id(x)] = y
if hasattr(x, '__getstate__'):
state = x.__getstate__()
else:
state = x.__dict__
state = deepcopy(state, memo)
if hasattr(y, '__setstate__'):
y.__setstate__(state)
else:
y.__dict__.update(state)
return y
d[types.InstanceType] = _deepcopy_inst
def _reconstruct(x, info, deep, memo=None):
if isinstance(info, str):
return x
assert isinstance(info, tuple)
if memo is None:
memo = {}
n = len(info)
assert n in (2, 3, 4, 5)
callable, args = info[:2]
if n > 2:
state = info[2]
else:
state = {}
if n > 3:
listiter = info[3]
else:
listiter = None
if n > 4:
dictiter = info[4]
else:
dictiter = None
if deep:
args = deepcopy(args, memo)
y = callable(*args)
memo[id(x)] = y
if state:
if deep:
state = deepcopy(state, memo)
if hasattr(y, '__setstate__'):
y.__setstate__(state)
else:
if isinstance(state, tuple) and len(state) == 2:
state, slotstate = state
else:
slotstate = None
if state is not None:
y.__dict__.update(state)
if slotstate is not None:
for key, value in slotstate.iteritems():
setattr(y, key, value)
if listiter is not None:
for item in listiter:
if deep:
item = deepcopy(item, memo)
y.append(item)
if dictiter is not None:
for key, value in dictiter:
if deep:
key = deepcopy(key, memo)
value = deepcopy(value, memo)
y[key] = value
return y
del d
del types
# Helper for instance creation without calling __init__
class _EmptyClass:
pass
def _test():
l = [None, 1, 2L, 3.14, 'xyzzy', (1, 2L), [3.14, 'abc'],
{'abc': 'ABC'}, (), [], {}]
l1 = copy(l)
print l1==l
l1 = map(copy, l)
print l1==l
l1 = deepcopy(l)
print l1==l
class C:
def __init__(self, arg=None):
self.a = 1
self.arg = arg
if __name__ == '__main__':
import sys
file = sys.argv[0]
else:
file = __file__
self.fp = open(file)
self.fp.close()
def __getstate__(self):
return {'a': self.a, 'arg': self.arg}
def __setstate__(self, state):
for key, value in state.iteritems():
setattr(self, key, value)
def __deepcopy__(self, memo=None):
new = self.__class__(deepcopy(self.arg, memo))
new.a = self.a
return new
c = C('argument sketch')
l.append(c)
l2 = copy(l)
print l == l2
print l
print l2
l2 = deepcopy(l)
print l == l2
print l
print l2
l.append({l[1]: l, 'xyz': l[2]})
l3 = copy(l)
import repr
print map(repr.repr, l)
print map(repr.repr, l1)
print map(repr.repr, l2)
print map(repr.repr, l3)
l3 = deepcopy(l)
import repr
print map(repr.repr, l)
print map(repr.repr, l1)
print map(repr.repr, l2)
print map(repr.repr, l3)
class odict(dict):
def __init__(self, d = {}):
self.a = 99
dict.__init__(self, d)
def __setitem__(self, k, i):
dict.__setitem__(self, k, i)
self.a
o = odict({"A" : "B"})
x = deepcopy(o)
print(o, x)
if __name__ == '__main__':
_test()

View File

@@ -0,0 +1,201 @@
"""Helper to provide extensibility for pickle/cPickle.
This is only useful to add pickle support for extension types defined in
C, not for instances of user-defined classes.
"""
from types import ClassType as _ClassType
__all__ = ["pickle", "constructor",
"add_extension", "remove_extension", "clear_extension_cache"]
dispatch_table = {}
def pickle(ob_type, pickle_function, constructor_ob=None):
if type(ob_type) is _ClassType:
raise TypeError("copy_reg is not intended for use with classes")
if not hasattr(pickle_function, '__call__'):
raise TypeError("reduction functions must be callable")
dispatch_table[ob_type] = pickle_function
# The constructor_ob function is a vestige of safe for unpickling.
# There is no reason for the caller to pass it anymore.
if constructor_ob is not None:
constructor(constructor_ob)
def constructor(object):
if not hasattr(object, '__call__'):
raise TypeError("constructors must be callable")
# Example: provide pickling support for complex numbers.
try:
complex
except NameError:
pass
else:
def pickle_complex(c):
return complex, (c.real, c.imag)
pickle(complex, pickle_complex, complex)
# Support for pickling new-style objects
def _reconstructor(cls, base, state):
if base is object:
obj = object.__new__(cls)
else:
obj = base.__new__(cls, state)
if base.__init__ != object.__init__:
base.__init__(obj, state)
return obj
_HEAPTYPE = 1<<9
# Python code for object.__reduce_ex__ for protocols 0 and 1
def _reduce_ex(self, proto):
assert proto < 2
for base in self.__class__.__mro__:
if hasattr(base, '__flags__') and not base.__flags__ & _HEAPTYPE:
break
else:
base = object # not really reachable
if base is object:
state = None
else:
if base is self.__class__:
raise TypeError, "can't pickle %s objects" % base.__name__
state = base(self)
args = (self.__class__, base, state)
try:
getstate = self.__getstate__
except AttributeError:
if getattr(self, "__slots__", None):
raise TypeError("a class that defines __slots__ without "
"defining __getstate__ cannot be pickled")
try:
dict = self.__dict__
except AttributeError:
dict = None
else:
dict = getstate()
if dict:
return _reconstructor, args, dict
else:
return _reconstructor, args
# Helper for __reduce_ex__ protocol 2
def __newobj__(cls, *args):
return cls.__new__(cls, *args)
def _slotnames(cls):
"""Return a list of slot names for a given class.
This needs to find slots defined by the class and its bases, so we
can't simply return the __slots__ attribute. We must walk down
the Method Resolution Order and concatenate the __slots__ of each
class found there. (This assumes classes don't modify their
__slots__ attribute to misrepresent their slots after the class is
defined.)
"""
# Get the value from a cache in the class if possible
names = cls.__dict__.get("__slotnames__")
if names is not None:
return names
# Not cached -- calculate the value
names = []
if not hasattr(cls, "__slots__"):
# This class has no slots
pass
else:
# Slots found -- gather slot names from all base classes
for c in cls.__mro__:
if "__slots__" in c.__dict__:
slots = c.__dict__['__slots__']
# if class has a single slot, it can be given as a string
if isinstance(slots, basestring):
slots = (slots,)
for name in slots:
# special descriptors
if name in ("__dict__", "__weakref__"):
continue
# mangled names
elif name.startswith('__') and not name.endswith('__'):
names.append('_%s%s' % (c.__name__, name))
else:
names.append(name)
# Cache the outcome in the class if at all possible
try:
cls.__slotnames__ = names
except:
pass # But don't die if we can't
return names
# A registry of extension codes. This is an ad-hoc compression
# mechanism. Whenever a global reference to <module>, <name> is about
# to be pickled, the (<module>, <name>) tuple is looked up here to see
# if it is a registered extension code for it. Extension codes are
# universal, so that the meaning of a pickle does not depend on
# context. (There are also some codes reserved for local use that
# don't have this restriction.) Codes are positive ints; 0 is
# reserved.
_extension_registry = {} # key -> code
_inverted_registry = {} # code -> key
_extension_cache = {} # code -> object
# Don't ever rebind those names: cPickle grabs a reference to them when
# it's initialized, and won't see a rebinding.
def add_extension(module, name, code):
"""Register an extension code."""
code = int(code)
if not 1 <= code <= 0x7fffffff:
raise ValueError, "code out of range"
key = (module, name)
if (_extension_registry.get(key) == code and
_inverted_registry.get(code) == key):
return # Redundant registrations are benign
if key in _extension_registry:
raise ValueError("key %s is already registered with code %s" %
(key, _extension_registry[key]))
if code in _inverted_registry:
raise ValueError("code %s is already in use for key %s" %
(code, _inverted_registry[code]))
_extension_registry[key] = code
_inverted_registry[code] = key
def remove_extension(module, name, code):
"""Unregister an extension code. For testing only."""
key = (module, name)
if (_extension_registry.get(key) != code or
_inverted_registry.get(code) != key):
raise ValueError("key %s is not registered with code %s" %
(key, code))
del _extension_registry[key]
del _inverted_registry[code]
if code in _extension_cache:
del _extension_cache[code]
def clear_extension_cache():
_extension_cache.clear()
# Standard extension code assignments
# Reserved ranges
# First Last Count Purpose
# 1 127 127 Reserved for Python standard library
# 128 191 64 Reserved for Zope
# 192 239 48 Reserved for 3rd parties
# 240 255 16 Reserved for private use (will never be assigned)
# 256 Inf Inf Reserved for future assignment
# Extension codes are assigned by the Python Software Foundation.

451
src/python/stdlib/csv.py Normal file
View File

@@ -0,0 +1,451 @@
"""
csv.py - read/write/investigate CSV files
"""
import re
from functools import reduce
from _csv import Error, __version__, writer, reader, register_dialect, \
unregister_dialect, get_dialect, list_dialects, \
field_size_limit, \
QUOTE_MINIMAL, QUOTE_ALL, QUOTE_NONNUMERIC, QUOTE_NONE, \
__doc__
from _csv import Dialect as _Dialect
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO
__all__ = [ "QUOTE_MINIMAL", "QUOTE_ALL", "QUOTE_NONNUMERIC", "QUOTE_NONE",
"Error", "Dialect", "__doc__", "excel", "excel_tab",
"field_size_limit", "reader", "writer",
"register_dialect", "get_dialect", "list_dialects", "Sniffer",
"unregister_dialect", "__version__", "DictReader", "DictWriter" ]
class Dialect:
"""Describe an Excel dialect.
This must be subclassed (see csv.excel). Valid attributes are:
delimiter, quotechar, escapechar, doublequote, skipinitialspace,
lineterminator, quoting.
"""
_name = ""
_valid = False
# placeholders
delimiter = None
quotechar = None
escapechar = None
doublequote = None
skipinitialspace = None
lineterminator = None
quoting = None
def __init__(self):
if self.__class__ != Dialect:
self._valid = True
self._validate()
def _validate(self):
try:
_Dialect(self)
except TypeError, e:
# We do this for compatibility with py2.3
raise Error(str(e))
class excel(Dialect):
"""Describe the usual properties of Excel-generated CSV files."""
delimiter = ','
quotechar = '"'
doublequote = True
skipinitialspace = False
lineterminator = '\r\n'
quoting = QUOTE_MINIMAL
register_dialect("excel", excel)
class excel_tab(excel):
"""Describe the usual properties of Excel-generated TAB-delimited files."""
delimiter = '\t'
register_dialect("excel-tab", excel_tab)
class DictReader:
def __init__(self, f, fieldnames=None, restkey=None, restval=None,
dialect="excel", *args, **kwds):
self._fieldnames = fieldnames # list of keys for the dict
self.restkey = restkey # key to catch long rows
self.restval = restval # default value for short rows
self.reader = reader(f, dialect, *args, **kwds)
self.dialect = dialect
self.line_num = 0
def __iter__(self):
return self
@property
def fieldnames(self):
if self._fieldnames is None:
try:
self._fieldnames = self.reader.next()
except StopIteration:
pass
self.line_num = self.reader.line_num
return self._fieldnames
@fieldnames.setter
def fieldnames(self, value):
self._fieldnames = value
def next(self):
if self.line_num == 0:
# Used only for its side effect.
self.fieldnames
row = self.reader.next()
self.line_num = self.reader.line_num
# unlike the basic reader, we prefer not to return blanks,
# because we will typically wind up with a dict full of None
# values
while row == []:
row = self.reader.next()
d = dict(zip(self.fieldnames, row))
lf = len(self.fieldnames)
lr = len(row)
if lf < lr:
d[self.restkey] = row[lf:]
elif lf > lr:
for key in self.fieldnames[lr:]:
d[key] = self.restval
return d
class DictWriter:
def __init__(self, f, fieldnames, restval="", extrasaction="raise",
dialect="excel", *args, **kwds):
self.fieldnames = fieldnames # list of keys for the dict
self.restval = restval # for writing short dicts
if extrasaction.lower() not in ("raise", "ignore"):
raise ValueError, \
("extrasaction (%s) must be 'raise' or 'ignore'" %
extrasaction)
self.extrasaction = extrasaction
self.writer = writer(f, dialect, *args, **kwds)
def writeheader(self):
header = dict(zip(self.fieldnames, self.fieldnames))
self.writerow(header)
def _dict_to_list(self, rowdict):
if self.extrasaction == "raise":
wrong_fields = [k for k in rowdict if k not in self.fieldnames]
if wrong_fields:
raise ValueError("dict contains fields not in fieldnames: " +
", ".join(wrong_fields))
return [rowdict.get(key, self.restval) for key in self.fieldnames]
def writerow(self, rowdict):
return self.writer.writerow(self._dict_to_list(rowdict))
def writerows(self, rowdicts):
rows = []
for rowdict in rowdicts:
rows.append(self._dict_to_list(rowdict))
return self.writer.writerows(rows)
# Guard Sniffer's type checking against builds that exclude complex()
try:
complex
except NameError:
complex = float
class Sniffer:
'''
"Sniffs" the format of a CSV file (i.e. delimiter, quotechar)
Returns a Dialect object.
'''
def __init__(self):
# in case there is more than one possible delimiter
self.preferred = [',', '\t', ';', ' ', ':']
def sniff(self, sample, delimiters=None):
"""
Returns a dialect (or None) corresponding to the sample
"""
quotechar, doublequote, delimiter, skipinitialspace = \
self._guess_quote_and_delimiter(sample, delimiters)
if not delimiter:
delimiter, skipinitialspace = self._guess_delimiter(sample,
delimiters)
if not delimiter:
raise Error, "Could not determine delimiter"
class dialect(Dialect):
_name = "sniffed"
lineterminator = '\r\n'
quoting = QUOTE_MINIMAL
# escapechar = ''
dialect.doublequote = doublequote
dialect.delimiter = delimiter
# _csv.reader won't accept a quotechar of ''
dialect.quotechar = quotechar or '"'
dialect.skipinitialspace = skipinitialspace
return dialect
def _guess_quote_and_delimiter(self, data, delimiters):
"""
Looks for text enclosed between two identical quotes
(the probable quotechar) which are preceded and followed
by the same character (the probable delimiter).
For example:
,'some text',
The quote with the most wins, same with the delimiter.
If there is no quotechar the delimiter can't be determined
this way.
"""
matches = []
for restr in ('(?P<delim>[^\w\n"\'])(?P<space> ?)(?P<quote>["\']).*?(?P=quote)(?P=delim)', # ,".*?",
'(?:^|\n)(?P<quote>["\']).*?(?P=quote)(?P<delim>[^\w\n"\'])(?P<space> ?)', # ".*?",
'(?P<delim>>[^\w\n"\'])(?P<space> ?)(?P<quote>["\']).*?(?P=quote)(?:$|\n)', # ,".*?"
'(?:^|\n)(?P<quote>["\']).*?(?P=quote)(?:$|\n)'): # ".*?" (no delim, no space)
regexp = re.compile(restr, re.DOTALL | re.MULTILINE)
matches = regexp.findall(data)
if matches:
break
if not matches:
# (quotechar, doublequote, delimiter, skipinitialspace)
return ('', False, None, 0)
quotes = {}
delims = {}
spaces = 0
for m in matches:
n = regexp.groupindex['quote'] - 1
key = m[n]
if key:
quotes[key] = quotes.get(key, 0) + 1
try:
n = regexp.groupindex['delim'] - 1
key = m[n]
except KeyError:
continue
if key and (delimiters is None or key in delimiters):
delims[key] = delims.get(key, 0) + 1
try:
n = regexp.groupindex['space'] - 1
except KeyError:
continue
if m[n]:
spaces += 1
quotechar = reduce(lambda a, b, quotes = quotes:
(quotes[a] > quotes[b]) and a or b, quotes.keys())
if delims:
delim = reduce(lambda a, b, delims = delims:
(delims[a] > delims[b]) and a or b, delims.keys())
skipinitialspace = delims[delim] == spaces
if delim == '\n': # most likely a file with a single column
delim = ''
else:
# there is *no* delimiter, it's a single column of quoted data
delim = ''
skipinitialspace = 0
# if we see an extra quote between delimiters, we've got a
# double quoted format
dq_regexp = re.compile(r"((%(delim)s)|^)\W*%(quote)s[^%(delim)s\n]*%(quote)s[^%(delim)s\n]*%(quote)s\W*((%(delim)s)|$)" % \
{'delim':delim, 'quote':quotechar}, re.MULTILINE)
if dq_regexp.search(data):
doublequote = True
else:
doublequote = False
return (quotechar, doublequote, delim, skipinitialspace)
def _guess_delimiter(self, data, delimiters):
"""
The delimiter /should/ occur the same number of times on
each row. However, due to malformed data, it may not. We don't want
an all or nothing approach, so we allow for small variations in this
number.
1) build a table of the frequency of each character on every line.
2) build a table of freqencies of this frequency (meta-frequency?),
e.g. 'x occurred 5 times in 10 rows, 6 times in 1000 rows,
7 times in 2 rows'
3) use the mode of the meta-frequency to determine the /expected/
frequency for that character
4) find out how often the character actually meets that goal
5) the character that best meets its goal is the delimiter
For performance reasons, the data is evaluated in chunks, so it can
try and evaluate the smallest portion of the data possible, evaluating
additional chunks as necessary.
"""
data = filter(None, data.split('\n'))
ascii = [chr(c) for c in range(127)] # 7-bit ASCII
# build frequency tables
chunkLength = min(10, len(data))
iteration = 0
charFrequency = {}
modes = {}
delims = {}
start, end = 0, min(chunkLength, len(data))
while start < len(data):
iteration += 1
for line in data[start:end]:
for char in ascii:
metaFrequency = charFrequency.get(char, {})
# must count even if frequency is 0
freq = line.count(char)
# value is the mode
metaFrequency[freq] = metaFrequency.get(freq, 0) + 1
charFrequency[char] = metaFrequency
for char in charFrequency.keys():
items = charFrequency[char].items()
if len(items) == 1 and items[0][0] == 0:
continue
# get the mode of the frequencies
if len(items) > 1:
modes[char] = reduce(lambda a, b: a[1] > b[1] and a or b,
items)
# adjust the mode - subtract the sum of all
# other frequencies
items.remove(modes[char])
modes[char] = (modes[char][0], modes[char][1]
- reduce(lambda a, b: (0, a[1] + b[1]),
items)[1])
else:
modes[char] = items[0]
# build a list of possible delimiters
modeList = modes.items()
total = float(chunkLength * iteration)
# (rows of consistent data) / (number of rows) = 100%
consistency = 1.0
# minimum consistency threshold
threshold = 0.9
while len(delims) == 0 and consistency >= threshold:
for k, v in modeList:
if v[0] > 0 and v[1] > 0:
if ((v[1]/total) >= consistency and
(delimiters is None or k in delimiters)):
delims[k] = v
consistency -= 0.01
if len(delims) == 1:
delim = delims.keys()[0]
skipinitialspace = (data[0].count(delim) ==
data[0].count("%c " % delim))
return (delim, skipinitialspace)
# analyze another chunkLength lines
start = end
end += chunkLength
if not delims:
return ('', 0)
# if there's more than one, fall back to a 'preferred' list
if len(delims) > 1:
for d in self.preferred:
if d in delims.keys():
skipinitialspace = (data[0].count(d) ==
data[0].count("%c " % d))
return (d, skipinitialspace)
# nothing else indicates a preference, pick the character that
# dominates(?)
items = [(v,k) for (k,v) in delims.items()]
items.sort()
delim = items[-1][1]
skipinitialspace = (data[0].count(delim) ==
data[0].count("%c " % delim))
return (delim, skipinitialspace)
def has_header(self, sample):
# Creates a dictionary of types of data in each column. If any
# column is of a single type (say, integers), *except* for the first
# row, then the first row is presumed to be labels. If the type
# can't be determined, it is assumed to be a string in which case
# the length of the string is the determining factor: if all of the
# rows except for the first are the same length, it's a header.
# Finally, a 'vote' is taken at the end for each column, adding or
# subtracting from the likelihood of the first row being a header.
rdr = reader(StringIO(sample), self.sniff(sample))
header = rdr.next() # assume first row is header
columns = len(header)
columnTypes = {}
for i in range(columns): columnTypes[i] = None
checked = 0
for row in rdr:
# arbitrary number of rows to check, to keep it sane
if checked > 20:
break
checked += 1
if len(row) != columns:
continue # skip rows that have irregular number of columns
for col in columnTypes.keys():
for thisType in [int, long, float, complex]:
try:
thisType(row[col])
break
except (ValueError, OverflowError):
pass
else:
# fallback to length of string
thisType = len(row[col])
# treat longs as ints
if thisType == long:
thisType = int
if thisType != columnTypes[col]:
if columnTypes[col] is None: # add new column type
columnTypes[col] = thisType
else:
# type is inconsistent, remove column from
# consideration
del columnTypes[col]
# finally, compare results against first row and "vote"
# on whether it's a header
hasHeader = 0
for col, colType in columnTypes.items():
if type(colType) == type(0): # it's a length
if len(header[col]) != colType:
hasHeader += 1
else:
hasHeader -= 1
else: # attempt typecast
try:
colType(header[col])
except (ValueError, TypeError):
hasHeader += 1
else:
hasHeader -= 1
return hasHeader > 0

View File

@@ -0,0 +1,549 @@
######################################################################
# This file should be kept compatible with Python 2.3, see PEP 291. #
######################################################################
"""create and manipulate C data types in Python"""
import os as _os, sys as _sys
__version__ = "1.1.0"
from _ctypes import Union, Structure, Array
from _ctypes import _Pointer
from _ctypes import CFuncPtr as _CFuncPtr
from _ctypes import __version__ as _ctypes_version
from _ctypes import RTLD_LOCAL, RTLD_GLOBAL
from _ctypes import ArgumentError
from struct import calcsize as _calcsize
if __version__ != _ctypes_version:
raise Exception("Version number mismatch", __version__, _ctypes_version)
if _os.name in ("nt", "ce"):
from _ctypes import FormatError
DEFAULT_MODE = RTLD_LOCAL
if _os.name == "posix" and _sys.platform == "darwin":
# On OS X 10.3, we use RTLD_GLOBAL as default mode
# because RTLD_LOCAL does not work at least on some
# libraries. OS X 10.3 is Darwin 7, so we check for
# that.
if int(_os.uname()[2].split('.')[0]) < 8:
DEFAULT_MODE = RTLD_GLOBAL
from _ctypes import FUNCFLAG_CDECL as _FUNCFLAG_CDECL, \
FUNCFLAG_PYTHONAPI as _FUNCFLAG_PYTHONAPI, \
FUNCFLAG_USE_ERRNO as _FUNCFLAG_USE_ERRNO, \
FUNCFLAG_USE_LASTERROR as _FUNCFLAG_USE_LASTERROR
"""
WINOLEAPI -> HRESULT
WINOLEAPI_(type)
STDMETHODCALLTYPE
STDMETHOD(name)
STDMETHOD_(type, name)
STDAPICALLTYPE
"""
def create_string_buffer(init, size=None):
"""create_string_buffer(aString) -> character array
create_string_buffer(anInteger) -> character array
create_string_buffer(aString, anInteger) -> character array
"""
if isinstance(init, (str, unicode)):
if size is None:
size = len(init)+1
buftype = c_char * size
buf = buftype()
buf.value = init
return buf
elif isinstance(init, (int, long)):
buftype = c_char * init
buf = buftype()
return buf
raise TypeError(init)
def c_buffer(init, size=None):
## "deprecated, use create_string_buffer instead"
## import warnings
## warnings.warn("c_buffer is deprecated, use create_string_buffer instead",
## DeprecationWarning, stacklevel=2)
return create_string_buffer(init, size)
_c_functype_cache = {}
def CFUNCTYPE(restype, *argtypes, **kw):
"""CFUNCTYPE(restype, *argtypes,
use_errno=False, use_last_error=False) -> function prototype.
restype: the result type
argtypes: a sequence specifying the argument types
The function prototype can be called in different ways to create a
callable object:
prototype(integer address) -> foreign function
prototype(callable) -> create and return a C callable function from callable
prototype(integer index, method name[, paramflags]) -> foreign function calling a COM method
prototype((ordinal number, dll object)[, paramflags]) -> foreign function exported by ordinal
prototype((function name, dll object)[, paramflags]) -> foreign function exported by name
"""
flags = _FUNCFLAG_CDECL
if kw.pop("use_errno", False):
flags |= _FUNCFLAG_USE_ERRNO
if kw.pop("use_last_error", False):
flags |= _FUNCFLAG_USE_LASTERROR
if kw:
raise ValueError("unexpected keyword argument(s) %s" % kw.keys())
try:
return _c_functype_cache[(restype, argtypes, flags)]
except KeyError:
class CFunctionType(_CFuncPtr):
_argtypes_ = argtypes
_restype_ = restype
_flags_ = flags
_c_functype_cache[(restype, argtypes, flags)] = CFunctionType
return CFunctionType
if _os.name in ("nt", "ce"):
from _ctypes import LoadLibrary as _dlopen
from _ctypes import FUNCFLAG_STDCALL as _FUNCFLAG_STDCALL
if _os.name == "ce":
# 'ce' doesn't have the stdcall calling convention
_FUNCFLAG_STDCALL = _FUNCFLAG_CDECL
_win_functype_cache = {}
def WINFUNCTYPE(restype, *argtypes, **kw):
# docstring set later (very similar to CFUNCTYPE.__doc__)
flags = _FUNCFLAG_STDCALL
if kw.pop("use_errno", False):
flags |= _FUNCFLAG_USE_ERRNO
if kw.pop("use_last_error", False):
flags |= _FUNCFLAG_USE_LASTERROR
if kw:
raise ValueError("unexpected keyword argument(s) %s" % kw.keys())
try:
return _win_functype_cache[(restype, argtypes, flags)]
except KeyError:
class WinFunctionType(_CFuncPtr):
_argtypes_ = argtypes
_restype_ = restype
_flags_ = flags
_win_functype_cache[(restype, argtypes, flags)] = WinFunctionType
return WinFunctionType
if WINFUNCTYPE.__doc__:
WINFUNCTYPE.__doc__ = CFUNCTYPE.__doc__.replace("CFUNCTYPE", "WINFUNCTYPE")
elif _os.name == "posix":
from _ctypes import dlopen as _dlopen
from _ctypes import sizeof, byref, addressof, alignment, resize
from _ctypes import get_errno, set_errno
from _ctypes import _SimpleCData
def _check_size(typ, typecode=None):
# Check if sizeof(ctypes_type) against struct.calcsize. This
# should protect somewhat against a misconfigured libffi.
from struct import calcsize
if typecode is None:
# Most _type_ codes are the same as used in struct
typecode = typ._type_
actual, required = sizeof(typ), calcsize(typecode)
if actual != required:
raise SystemError("sizeof(%s) wrong: %d instead of %d" % \
(typ, actual, required))
class py_object(_SimpleCData):
_type_ = "O"
def __repr__(self):
try:
return super(py_object, self).__repr__()
except ValueError:
return "%s(<NULL>)" % type(self).__name__
_check_size(py_object, "P")
class c_short(_SimpleCData):
_type_ = "h"
_check_size(c_short)
class c_ushort(_SimpleCData):
_type_ = "H"
_check_size(c_ushort)
class c_long(_SimpleCData):
_type_ = "l"
_check_size(c_long)
class c_ulong(_SimpleCData):
_type_ = "L"
_check_size(c_ulong)
if _calcsize("i") == _calcsize("l"):
# if int and long have the same size, make c_int an alias for c_long
c_int = c_long
c_uint = c_ulong
else:
class c_int(_SimpleCData):
_type_ = "i"
_check_size(c_int)
class c_uint(_SimpleCData):
_type_ = "I"
_check_size(c_uint)
class c_float(_SimpleCData):
_type_ = "f"
_check_size(c_float)
class c_double(_SimpleCData):
_type_ = "d"
_check_size(c_double)
class c_longdouble(_SimpleCData):
_type_ = "g"
if sizeof(c_longdouble) == sizeof(c_double):
c_longdouble = c_double
if _calcsize("l") == _calcsize("q"):
# if long and long long have the same size, make c_longlong an alias for c_long
c_longlong = c_long
c_ulonglong = c_ulong
else:
class c_longlong(_SimpleCData):
_type_ = "q"
_check_size(c_longlong)
class c_ulonglong(_SimpleCData):
_type_ = "Q"
## def from_param(cls, val):
## return ('d', float(val), val)
## from_param = classmethod(from_param)
_check_size(c_ulonglong)
class c_ubyte(_SimpleCData):
_type_ = "B"
c_ubyte.__ctype_le__ = c_ubyte.__ctype_be__ = c_ubyte
# backward compatibility:
##c_uchar = c_ubyte
_check_size(c_ubyte)
class c_byte(_SimpleCData):
_type_ = "b"
c_byte.__ctype_le__ = c_byte.__ctype_be__ = c_byte
_check_size(c_byte)
class c_char(_SimpleCData):
_type_ = "c"
c_char.__ctype_le__ = c_char.__ctype_be__ = c_char
_check_size(c_char)
class c_char_p(_SimpleCData):
_type_ = "z"
if _os.name == "nt":
def __repr__(self):
if not windll.kernel32.IsBadStringPtrA(self, -1):
return "%s(%r)" % (self.__class__.__name__, self.value)
return "%s(%s)" % (self.__class__.__name__, cast(self, c_void_p).value)
else:
def __repr__(self):
return "%s(%s)" % (self.__class__.__name__, cast(self, c_void_p).value)
_check_size(c_char_p, "P")
class c_void_p(_SimpleCData):
_type_ = "P"
c_voidp = c_void_p # backwards compatibility (to a bug)
_check_size(c_void_p)
class c_bool(_SimpleCData):
_type_ = "?"
from _ctypes import POINTER, pointer, _pointer_type_cache
try:
from _ctypes import set_conversion_mode
except ImportError:
pass
else:
if _os.name in ("nt", "ce"):
set_conversion_mode("mbcs", "ignore")
else:
set_conversion_mode("ascii", "strict")
class c_wchar_p(_SimpleCData):
_type_ = "Z"
class c_wchar(_SimpleCData):
_type_ = "u"
POINTER(c_wchar).from_param = c_wchar_p.from_param #_SimpleCData.c_wchar_p_from_param
def create_unicode_buffer(init, size=None):
"""create_unicode_buffer(aString) -> character array
create_unicode_buffer(anInteger) -> character array
create_unicode_buffer(aString, anInteger) -> character array
"""
if isinstance(init, (str, unicode)):
if size is None:
size = len(init)+1
buftype = c_wchar * size
buf = buftype()
buf.value = init
return buf
elif isinstance(init, (int, long)):
buftype = c_wchar * init
buf = buftype()
return buf
raise TypeError(init)
POINTER(c_char).from_param = c_char_p.from_param #_SimpleCData.c_char_p_from_param
# XXX Deprecated
def SetPointerType(pointer, cls):
if _pointer_type_cache.get(cls, None) is not None:
raise RuntimeError("This type already exists in the cache")
if id(pointer) not in _pointer_type_cache:
raise RuntimeError("What's this???")
pointer.set_type(cls)
_pointer_type_cache[cls] = pointer
del _pointer_type_cache[id(pointer)]
# XXX Deprecated
def ARRAY(typ, len):
return typ * len
################################################################
class CDLL(object):
"""An instance of this class represents a loaded dll/shared
library, exporting functions using the standard C calling
convention (named 'cdecl' on Windows).
The exported functions can be accessed as attributes, or by
indexing with the function name. Examples:
<obj>.qsort -> callable object
<obj>['qsort'] -> callable object
Calling the functions releases the Python GIL during the call and
reacquires it afterwards.
"""
_func_flags_ = _FUNCFLAG_CDECL
_func_restype_ = c_int
def __init__(self, name, mode=DEFAULT_MODE, handle=None,
use_errno=False,
use_last_error=False):
self._name = name
flags = self._func_flags_
if use_errno:
flags |= _FUNCFLAG_USE_ERRNO
if use_last_error:
flags |= _FUNCFLAG_USE_LASTERROR
class _FuncPtr(_CFuncPtr):
_flags_ = flags
_restype_ = self._func_restype_
self._FuncPtr = _FuncPtr
if handle is None:
self._handle = _dlopen(self._name, mode)
else:
self._handle = handle
def __repr__(self):
return "<%s '%s', handle %x at %x>" % \
(self.__class__.__name__, self._name,
(self._handle & (_sys.maxint*2 + 1)),
id(self) & (_sys.maxint*2 + 1))
def __getattr__(self, name):
if name.startswith('__') and name.endswith('__'):
raise AttributeError(name)
func = self.__getitem__(name)
setattr(self, name, func)
return func
def __getitem__(self, name_or_ordinal):
func = self._FuncPtr((name_or_ordinal, self))
if not isinstance(name_or_ordinal, (int, long)):
func.__name__ = name_or_ordinal
return func
class PyDLL(CDLL):
"""This class represents the Python library itself. It allows to
access Python API functions. The GIL is not released, and
Python exceptions are handled correctly.
"""
_func_flags_ = _FUNCFLAG_CDECL | _FUNCFLAG_PYTHONAPI
if _os.name in ("nt", "ce"):
class WinDLL(CDLL):
"""This class represents a dll exporting functions using the
Windows stdcall calling convention.
"""
_func_flags_ = _FUNCFLAG_STDCALL
# XXX Hm, what about HRESULT as normal parameter?
# Mustn't it derive from c_long then?
from _ctypes import _check_HRESULT, _SimpleCData
class HRESULT(_SimpleCData):
_type_ = "l"
# _check_retval_ is called with the function's result when it
# is used as restype. It checks for the FAILED bit, and
# raises a WindowsError if it is set.
#
# The _check_retval_ method is implemented in C, so that the
# method definition itself is not included in the traceback
# when it raises an error - that is what we want (and Python
# doesn't have a way to raise an exception in the caller's
# frame).
_check_retval_ = _check_HRESULT
class OleDLL(CDLL):
"""This class represents a dll exporting functions using the
Windows stdcall calling convention, and returning HRESULT.
HRESULT error values are automatically raised as WindowsError
exceptions.
"""
_func_flags_ = _FUNCFLAG_STDCALL
_func_restype_ = HRESULT
class LibraryLoader(object):
def __init__(self, dlltype):
self._dlltype = dlltype
def __getattr__(self, name):
if name[0] == '_':
raise AttributeError(name)
dll = self._dlltype(name)
setattr(self, name, dll)
return dll
def __getitem__(self, name):
return getattr(self, name)
def LoadLibrary(self, name):
return self._dlltype(name)
cdll = LibraryLoader(CDLL)
pydll = LibraryLoader(PyDLL)
if _os.name in ("nt", "ce"):
pythonapi = PyDLL("python dll", None, _sys.dllhandle)
elif _sys.platform == "cygwin":
pythonapi = PyDLL("libpython%d.%d.dll" % _sys.version_info[:2])
else:
pythonapi = PyDLL(None)
if _os.name in ("nt", "ce"):
windll = LibraryLoader(WinDLL)
oledll = LibraryLoader(OleDLL)
if _os.name == "nt":
GetLastError = windll.kernel32.GetLastError
else:
GetLastError = windll.coredll.GetLastError
from _ctypes import get_last_error, set_last_error
def WinError(code=None, descr=None):
if code is None:
code = GetLastError()
if descr is None:
descr = FormatError(code).strip()
return WindowsError(code, descr)
_pointer_type_cache[None] = c_void_p
if sizeof(c_uint) == sizeof(c_void_p):
c_size_t = c_uint
c_ssize_t = c_int
elif sizeof(c_ulong) == sizeof(c_void_p):
c_size_t = c_ulong
c_ssize_t = c_long
elif sizeof(c_ulonglong) == sizeof(c_void_p):
c_size_t = c_ulonglong
c_ssize_t = c_longlong
# functions
from _ctypes import _memmove_addr, _memset_addr, _string_at_addr, _cast_addr
## void *memmove(void *, const void *, size_t);
memmove = CFUNCTYPE(c_void_p, c_void_p, c_void_p, c_size_t)(_memmove_addr)
## void *memset(void *, int, size_t)
memset = CFUNCTYPE(c_void_p, c_void_p, c_int, c_size_t)(_memset_addr)
def PYFUNCTYPE(restype, *argtypes):
class CFunctionType(_CFuncPtr):
_argtypes_ = argtypes
_restype_ = restype
_flags_ = _FUNCFLAG_CDECL | _FUNCFLAG_PYTHONAPI
return CFunctionType
_cast = PYFUNCTYPE(py_object, c_void_p, py_object, py_object)(_cast_addr)
def cast(obj, typ):
return _cast(obj, obj, typ)
_string_at = PYFUNCTYPE(py_object, c_void_p, c_int)(_string_at_addr)
def string_at(ptr, size=-1):
"""string_at(addr[, size]) -> string
Return the string at addr."""
return _string_at(ptr, size)
try:
from _ctypes import _wstring_at_addr
except ImportError:
pass
else:
_wstring_at = PYFUNCTYPE(py_object, c_void_p, c_int)(_wstring_at_addr)
def wstring_at(ptr, size=-1):
"""wstring_at(addr[, size]) -> string
Return the string at addr."""
return _wstring_at(ptr, size)
if _os.name in ("nt", "ce"): # COM stuff
def DllGetClassObject(rclsid, riid, ppv):
try:
ccom = __import__("comtypes.server.inprocserver", globals(), locals(), ['*'])
except ImportError:
return -2147221231 # CLASS_E_CLASSNOTAVAILABLE
else:
return ccom.DllGetClassObject(rclsid, riid, ppv)
def DllCanUnloadNow():
try:
ccom = __import__("comtypes.server.inprocserver", globals(), locals(), ['*'])
except ImportError:
return 0 # S_OK
return ccom.DllCanUnloadNow()
from ctypes._endian import BigEndianStructure, LittleEndianStructure
# Fill in specifically-sized types
c_int8 = c_byte
c_uint8 = c_ubyte
for kind in [c_short, c_int, c_long, c_longlong]:
if sizeof(kind) == 2: c_int16 = kind
elif sizeof(kind) == 4: c_int32 = kind
elif sizeof(kind) == 8: c_int64 = kind
for kind in [c_ushort, c_uint, c_ulong, c_ulonglong]:
if sizeof(kind) == 2: c_uint16 = kind
elif sizeof(kind) == 4: c_uint32 = kind
elif sizeof(kind) == 8: c_uint64 = kind
del(kind)
# XXX for whatever reasons, creating the first instance of a callback
# function is needed for the unittests on Win64 to succeed. This MAY
# be a compiler bug, since the problem occurs only when _ctypes is
# compiled with the MS SDK compiler. Or an uninitialized variable?
CFUNCTYPE(c_int)(lambda: None)

View File

@@ -0,0 +1,60 @@
######################################################################
# This file should be kept compatible with Python 2.3, see PEP 291. #
######################################################################
import sys
from ctypes import *
_array_type = type(c_int * 3)
def _other_endian(typ):
"""Return the type with the 'other' byte order. Simple types like
c_int and so on already have __ctype_be__ and __ctype_le__
attributes which contain the types, for more complicated types
only arrays are supported.
"""
try:
return getattr(typ, _OTHER_ENDIAN)
except AttributeError:
if type(typ) == _array_type:
return _other_endian(typ._type_) * typ._length_
raise TypeError("This type does not support other endian: %s" % typ)
class _swapped_meta(type(Structure)):
def __setattr__(self, attrname, value):
if attrname == "_fields_":
fields = []
for desc in value:
name = desc[0]
typ = desc[1]
rest = desc[2:]
fields.append((name, _other_endian(typ)) + rest)
value = fields
super(_swapped_meta, self).__setattr__(attrname, value)
################################################################
# Note: The Structure metaclass checks for the *presence* (not the
# value!) of a _swapped_bytes_ attribute to determine the bit order in
# structures containing bit fields.
if sys.byteorder == "little":
_OTHER_ENDIAN = "__ctype_be__"
LittleEndianStructure = Structure
class BigEndianStructure(Structure):
"""Structure with big endian byte order"""
__metaclass__ = _swapped_meta
_swappedbytes_ = None
elif sys.byteorder == "big":
_OTHER_ENDIAN = "__ctype_le__"
BigEndianStructure = Structure
class LittleEndianStructure(Structure):
"""Structure with little endian byte order"""
__metaclass__ = _swapped_meta
_swappedbytes_ = None
else:
raise RuntimeError("Invalid byteorder")

View File

@@ -0,0 +1,7 @@
Files in this directory from from Bob Ippolito's py2app.
License: Any components of the py2app suite may be distributed under
the MIT or PSF open source licenses.
This is version 1.0, SVN revision 789, from 2006/01/25.
The main repository is http://svn.red-bean.com/bob/macholib/trunk/macholib/

View File

@@ -0,0 +1,12 @@
######################################################################
# This file should be kept compatible with Python 2.3, see PEP 291. #
######################################################################
"""
Enough Mach-O to make your head spin.
See the relevant header files in /usr/include/mach-o
And also Apple's documentation.
"""
__version__ = '1.0'

View File

@@ -0,0 +1,169 @@
######################################################################
# This file should be kept compatible with Python 2.3, see PEP 291. #
######################################################################
"""
dyld emulation
"""
import os
from framework import framework_info
from dylib import dylib_info
from itertools import *
__all__ = [
'dyld_find', 'framework_find',
'framework_info', 'dylib_info',
]
# These are the defaults as per man dyld(1)
#
DEFAULT_FRAMEWORK_FALLBACK = [
os.path.expanduser("~/Library/Frameworks"),
"/Library/Frameworks",
"/Network/Library/Frameworks",
"/System/Library/Frameworks",
]
DEFAULT_LIBRARY_FALLBACK = [
os.path.expanduser("~/lib"),
"/usr/local/lib",
"/lib",
"/usr/lib",
]
def ensure_utf8(s):
"""Not all of PyObjC and Python understand unicode paths very well yet"""
if isinstance(s, unicode):
return s.encode('utf8')
return s
def dyld_env(env, var):
if env is None:
env = os.environ
rval = env.get(var)
if rval is None:
return []
return rval.split(':')
def dyld_image_suffix(env=None):
if env is None:
env = os.environ
return env.get('DYLD_IMAGE_SUFFIX')
def dyld_framework_path(env=None):
return dyld_env(env, 'DYLD_FRAMEWORK_PATH')
def dyld_library_path(env=None):
return dyld_env(env, 'DYLD_LIBRARY_PATH')
def dyld_fallback_framework_path(env=None):
return dyld_env(env, 'DYLD_FALLBACK_FRAMEWORK_PATH')
def dyld_fallback_library_path(env=None):
return dyld_env(env, 'DYLD_FALLBACK_LIBRARY_PATH')
def dyld_image_suffix_search(iterator, env=None):
"""For a potential path iterator, add DYLD_IMAGE_SUFFIX semantics"""
suffix = dyld_image_suffix(env)
if suffix is None:
return iterator
def _inject(iterator=iterator, suffix=suffix):
for path in iterator:
if path.endswith('.dylib'):
yield path[:-len('.dylib')] + suffix + '.dylib'
else:
yield path + suffix
yield path
return _inject()
def dyld_override_search(name, env=None):
# If DYLD_FRAMEWORK_PATH is set and this dylib_name is a
# framework name, use the first file that exists in the framework
# path if any. If there is none go on to search the DYLD_LIBRARY_PATH
# if any.
framework = framework_info(name)
if framework is not None:
for path in dyld_framework_path(env):
yield os.path.join(path, framework['name'])
# If DYLD_LIBRARY_PATH is set then use the first file that exists
# in the path. If none use the original name.
for path in dyld_library_path(env):
yield os.path.join(path, os.path.basename(name))
def dyld_executable_path_search(name, executable_path=None):
# If we haven't done any searching and found a library and the
# dylib_name starts with "@executable_path/" then construct the
# library name.
if name.startswith('@executable_path/') and executable_path is not None:
yield os.path.join(executable_path, name[len('@executable_path/'):])
def dyld_default_search(name, env=None):
yield name
framework = framework_info(name)
if framework is not None:
fallback_framework_path = dyld_fallback_framework_path(env)
for path in fallback_framework_path:
yield os.path.join(path, framework['name'])
fallback_library_path = dyld_fallback_library_path(env)
for path in fallback_library_path:
yield os.path.join(path, os.path.basename(name))
if framework is not None and not fallback_framework_path:
for path in DEFAULT_FRAMEWORK_FALLBACK:
yield os.path.join(path, framework['name'])
if not fallback_library_path:
for path in DEFAULT_LIBRARY_FALLBACK:
yield os.path.join(path, os.path.basename(name))
def dyld_find(name, executable_path=None, env=None):
"""
Find a library or framework using dyld semantics
"""
name = ensure_utf8(name)
executable_path = ensure_utf8(executable_path)
for path in dyld_image_suffix_search(chain(
dyld_override_search(name, env),
dyld_executable_path_search(name, executable_path),
dyld_default_search(name, env),
), env):
if os.path.isfile(path):
return path
raise ValueError("dylib %s could not be found" % (name,))
def framework_find(fn, executable_path=None, env=None):
"""
Find a framework using dyld semantics in a very loose manner.
Will take input such as:
Python
Python.framework
Python.framework/Versions/Current
"""
try:
return dyld_find(fn, executable_path=executable_path, env=env)
except ValueError, e:
pass
fmwk_index = fn.rfind('.framework')
if fmwk_index == -1:
fmwk_index = len(fn)
fn += '.framework'
fn = os.path.join(fn, os.path.basename(fn[:fmwk_index]))
try:
return dyld_find(fn, executable_path=executable_path, env=env)
except ValueError:
raise e
def test_dyld_find():
env = {}
assert dyld_find('libSystem.dylib') == '/usr/lib/libSystem.dylib'
assert dyld_find('System.framework/System') == '/System/Library/Frameworks/System.framework/System'
if __name__ == '__main__':
test_dyld_find()

View File

@@ -0,0 +1,66 @@
######################################################################
# This file should be kept compatible with Python 2.3, see PEP 291. #
######################################################################
"""
Generic dylib path manipulation
"""
import re
__all__ = ['dylib_info']
DYLIB_RE = re.compile(r"""(?x)
(?P<location>^.*)(?:^|/)
(?P<name>
(?P<shortname>\w+?)
(?:\.(?P<version>[^._]+))?
(?:_(?P<suffix>[^._]+))?
\.dylib$
)
""")
def dylib_info(filename):
"""
A dylib name can take one of the following four forms:
Location/Name.SomeVersion_Suffix.dylib
Location/Name.SomeVersion.dylib
Location/Name_Suffix.dylib
Location/Name.dylib
returns None if not found or a mapping equivalent to:
dict(
location='Location',
name='Name.SomeVersion_Suffix.dylib',
shortname='Name',
version='SomeVersion',
suffix='Suffix',
)
Note that SomeVersion and Suffix are optional and may be None
if not present.
"""
is_dylib = DYLIB_RE.match(filename)
if not is_dylib:
return None
return is_dylib.groupdict()
def test_dylib_info():
def d(location=None, name=None, shortname=None, version=None, suffix=None):
return dict(
location=location,
name=name,
shortname=shortname,
version=version,
suffix=suffix
)
assert dylib_info('completely/invalid') is None
assert dylib_info('completely/invalide_debug') is None
assert dylib_info('P/Foo.dylib') == d('P', 'Foo.dylib', 'Foo')
assert dylib_info('P/Foo_debug.dylib') == d('P', 'Foo_debug.dylib', 'Foo', suffix='debug')
assert dylib_info('P/Foo.A.dylib') == d('P', 'Foo.A.dylib', 'Foo', 'A')
assert dylib_info('P/Foo_debug.A.dylib') == d('P', 'Foo_debug.A.dylib', 'Foo_debug', 'A')
assert dylib_info('P/Foo.A_debug.dylib') == d('P', 'Foo.A_debug.dylib', 'Foo', 'A', 'debug')
if __name__ == '__main__':
test_dylib_info()

View File

@@ -0,0 +1,2 @@
#!/bin/sh
svn export --force http://svn.red-bean.com/bob/macholib/trunk/macholib/ .

View File

@@ -0,0 +1 @@
svn export --force http://svn.red-bean.com/bob/macholib/trunk/macholib/ .

View File

@@ -0,0 +1,68 @@
######################################################################
# This file should be kept compatible with Python 2.3, see PEP 291. #
######################################################################
"""
Generic framework path manipulation
"""
import re
__all__ = ['framework_info']
STRICT_FRAMEWORK_RE = re.compile(r"""(?x)
(?P<location>^.*)(?:^|/)
(?P<name>
(?P<shortname>\w+).framework/
(?:Versions/(?P<version>[^/]+)/)?
(?P=shortname)
(?:_(?P<suffix>[^_]+))?
)$
""")
def framework_info(filename):
"""
A framework name can take one of the following four forms:
Location/Name.framework/Versions/SomeVersion/Name_Suffix
Location/Name.framework/Versions/SomeVersion/Name
Location/Name.framework/Name_Suffix
Location/Name.framework/Name
returns None if not found, or a mapping equivalent to:
dict(
location='Location',
name='Name.framework/Versions/SomeVersion/Name_Suffix',
shortname='Name',
version='SomeVersion',
suffix='Suffix',
)
Note that SomeVersion and Suffix are optional and may be None
if not present
"""
is_framework = STRICT_FRAMEWORK_RE.match(filename)
if not is_framework:
return None
return is_framework.groupdict()
def test_framework_info():
def d(location=None, name=None, shortname=None, version=None, suffix=None):
return dict(
location=location,
name=name,
shortname=shortname,
version=version,
suffix=suffix
)
assert framework_info('completely/invalid') is None
assert framework_info('completely/invalid/_debug') is None
assert framework_info('P/F.framework') is None
assert framework_info('P/F.framework/_debug') is None
assert framework_info('P/F.framework/F') == d('P', 'F.framework/F', 'F')
assert framework_info('P/F.framework/F_debug') == d('P', 'F.framework/F_debug', 'F', suffix='debug')
assert framework_info('P/F.framework/Versions') is None
assert framework_info('P/F.framework/Versions/A') is None
assert framework_info('P/F.framework/Versions/A/F') == d('P', 'F.framework/Versions/A/F', 'F', 'A')
assert framework_info('P/F.framework/Versions/A/F_debug') == d('P', 'F.framework/Versions/A/F_debug', 'F', 'A', 'debug')
if __name__ == '__main__':
test_framework_info()

View File

@@ -0,0 +1,208 @@
import os, sys, unittest, getopt, time
use_resources = []
class ResourceDenied(Exception):
"""Test skipped because it requested a disallowed resource.
This is raised when a test calls requires() for a resource that
has not be enabled. Resources are defined by test modules.
"""
def is_resource_enabled(resource):
"""Test whether a resource is enabled.
If the caller's module is __main__ then automatically return True."""
if sys._getframe().f_back.f_globals.get("__name__") == "__main__":
return True
result = use_resources is not None and \
(resource in use_resources or "*" in use_resources)
if not result:
_unavail[resource] = None
return result
_unavail = {}
def requires(resource, msg=None):
"""Raise ResourceDenied if the specified resource is not available.
If the caller's module is __main__ then automatically return True."""
# see if the caller's module is __main__ - if so, treat as if
# the resource was set
if sys._getframe().f_back.f_globals.get("__name__") == "__main__":
return
if not is_resource_enabled(resource):
if msg is None:
msg = "Use of the `%s' resource not enabled" % resource
raise ResourceDenied(msg)
def find_package_modules(package, mask):
import fnmatch
if (hasattr(package, "__loader__") and
hasattr(package.__loader__, '_files')):
path = package.__name__.replace(".", os.path.sep)
mask = os.path.join(path, mask)
for fnm in package.__loader__._files.iterkeys():
if fnmatch.fnmatchcase(fnm, mask):
yield os.path.splitext(fnm)[0].replace(os.path.sep, ".")
else:
path = package.__path__[0]
for fnm in os.listdir(path):
if fnmatch.fnmatchcase(fnm, mask):
yield "%s.%s" % (package.__name__, os.path.splitext(fnm)[0])
def get_tests(package, mask, verbosity, exclude=()):
"""Return a list of skipped test modules, and a list of test cases."""
tests = []
skipped = []
for modname in find_package_modules(package, mask):
if modname.split(".")[-1] in exclude:
skipped.append(modname)
if verbosity > 1:
print >> sys.stderr, "Skipped %s: excluded" % modname
continue
try:
mod = __import__(modname, globals(), locals(), ['*'])
except ResourceDenied, detail:
skipped.append(modname)
if verbosity > 1:
print >> sys.stderr, "Skipped %s: %s" % (modname, detail)
continue
for name in dir(mod):
if name.startswith("_"):
continue
o = getattr(mod, name)
if type(o) is type(unittest.TestCase) and issubclass(o, unittest.TestCase):
tests.append(o)
return skipped, tests
def usage():
print __doc__
return 1
def test_with_refcounts(runner, verbosity, testcase):
"""Run testcase several times, tracking reference counts."""
import gc
import ctypes
ptc = ctypes._pointer_type_cache.copy()
cfc = ctypes._c_functype_cache.copy()
wfc = ctypes._win_functype_cache.copy()
# when searching for refcount leaks, we have to manually reset any
# caches that ctypes has.
def cleanup():
ctypes._pointer_type_cache = ptc.copy()
ctypes._c_functype_cache = cfc.copy()
ctypes._win_functype_cache = wfc.copy()
gc.collect()
test = unittest.makeSuite(testcase)
for i in range(5):
rc = sys.gettotalrefcount()
runner.run(test)
cleanup()
COUNT = 5
refcounts = [None] * COUNT
for i in range(COUNT):
rc = sys.gettotalrefcount()
runner.run(test)
cleanup()
refcounts[i] = sys.gettotalrefcount() - rc
if filter(None, refcounts):
print "%s leaks:\n\t" % testcase, refcounts
elif verbosity:
print "%s: ok." % testcase
class TestRunner(unittest.TextTestRunner):
def run(self, test, skipped):
"Run the given test case or test suite."
# Same as unittest.TextTestRunner.run, except that it reports
# skipped tests.
result = self._makeResult()
startTime = time.time()
test(result)
stopTime = time.time()
timeTaken = stopTime - startTime
result.printErrors()
self.stream.writeln(result.separator2)
run = result.testsRun
if _unavail: #skipped:
requested = _unavail.keys()
requested.sort()
self.stream.writeln("Ran %d test%s in %.3fs (%s module%s skipped)" %
(run, run != 1 and "s" or "", timeTaken,
len(skipped),
len(skipped) != 1 and "s" or ""))
self.stream.writeln("Unavailable resources: %s" % ", ".join(requested))
else:
self.stream.writeln("Ran %d test%s in %.3fs" %
(run, run != 1 and "s" or "", timeTaken))
self.stream.writeln()
if not result.wasSuccessful():
self.stream.write("FAILED (")
failed, errored = map(len, (result.failures, result.errors))
if failed:
self.stream.write("failures=%d" % failed)
if errored:
if failed: self.stream.write(", ")
self.stream.write("errors=%d" % errored)
self.stream.writeln(")")
else:
self.stream.writeln("OK")
return result
def main(*packages):
try:
opts, args = getopt.getopt(sys.argv[1:], "rqvu:x:")
except getopt.error:
return usage()
verbosity = 1
search_leaks = False
exclude = []
for flag, value in opts:
if flag == "-q":
verbosity -= 1
elif flag == "-v":
verbosity += 1
elif flag == "-r":
try:
sys.gettotalrefcount
except AttributeError:
print >> sys.stderr, "-r flag requires Python debug build"
return -1
search_leaks = True
elif flag == "-u":
use_resources.extend(value.split(","))
elif flag == "-x":
exclude.extend(value.split(","))
mask = "test_*.py"
if args:
mask = args[0]
for package in packages:
run_tests(package, mask, verbosity, search_leaks, exclude)
def run_tests(package, mask, verbosity, search_leaks, exclude):
skipped, testcases = get_tests(package, mask, verbosity, exclude)
runner = TestRunner(verbosity=verbosity)
suites = [unittest.makeSuite(o) for o in testcases]
suite = unittest.TestSuite(suites)
result = runner.run(suite, skipped)
if search_leaks:
# hunt for refcount leaks
runner = BasicTestRunner()
for t in testcases:
test_with_refcounts(runner, verbosity, t)
return bool(result.errors)
class BasicTestRunner:
def run(self, test):
result = unittest.TestResult()
test(result)
return result

View File

@@ -0,0 +1,19 @@
"""Usage: runtests.py [-q] [-r] [-v] [-u resources] [mask]
Run all tests found in this directory, and print a summary of the results.
Command line flags:
-q quiet mode: don't prnt anything while the tests are running
-r run tests repeatedly, look for refcount leaks
-u<resources>
Add resources to the lits of allowed resources. '*' allows all
resources.
-v verbose mode: print the test currently executed
-x<test1[,test2...]>
Exclude specified tests.
mask mask to select filenames containing testcases, wildcards allowed
"""
import sys
import ctypes.test
if __name__ == "__main__":
sys.exit(ctypes.test.main(ctypes.test))

View File

@@ -0,0 +1,60 @@
import unittest
from ctypes import *
class AnonTest(unittest.TestCase):
def test_anon(self):
class ANON(Union):
_fields_ = [("a", c_int),
("b", c_int)]
class Y(Structure):
_fields_ = [("x", c_int),
("_", ANON),
("y", c_int)]
_anonymous_ = ["_"]
self.assertEqual(Y.a.offset, sizeof(c_int))
self.assertEqual(Y.b.offset, sizeof(c_int))
self.assertEqual(ANON.a.offset, 0)
self.assertEqual(ANON.b.offset, 0)
def test_anon_nonseq(self):
# TypeError: _anonymous_ must be a sequence
self.assertRaises(TypeError,
lambda: type(Structure)("Name",
(Structure,),
{"_fields_": [], "_anonymous_": 42}))
def test_anon_nonmember(self):
# AttributeError: type object 'Name' has no attribute 'x'
self.assertRaises(AttributeError,
lambda: type(Structure)("Name",
(Structure,),
{"_fields_": [],
"_anonymous_": ["x"]}))
def test_nested(self):
class ANON_S(Structure):
_fields_ = [("a", c_int)]
class ANON_U(Union):
_fields_ = [("_", ANON_S),
("b", c_int)]
_anonymous_ = ["_"]
class Y(Structure):
_fields_ = [("x", c_int),
("_", ANON_U),
("y", c_int)]
_anonymous_ = ["_"]
self.assertEqual(Y.x.offset, 0)
self.assertEqual(Y.a.offset, sizeof(c_int))
self.assertEqual(Y.b.offset, sizeof(c_int))
self.assertEqual(Y._.offset, sizeof(c_int))
self.assertEqual(Y.y.offset, sizeof(c_int) * 2)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,64 @@
import unittest
from ctypes import *
from binascii import hexlify
import re
def dump(obj):
# helper function to dump memory contents in hex, with a hyphen
# between the bytes.
h = hexlify(memoryview(obj))
return re.sub(r"(..)", r"\1-", h)[:-1]
class Value(Structure):
_fields_ = [("val", c_byte)]
class Container(Structure):
_fields_ = [("pvalues", POINTER(Value))]
class Test(unittest.TestCase):
def test(self):
# create an array of 4 values
val_array = (Value * 4)()
# create a container, which holds a pointer to the pvalues array.
c = Container()
c.pvalues = val_array
# memory contains 4 NUL bytes now, that's correct
self.assertEqual("00-00-00-00", dump(val_array))
# set the values of the array through the pointer:
for i in range(4):
c.pvalues[i].val = i + 1
values = [c.pvalues[i].val for i in range(4)]
# These are the expected results: here s the bug!
self.assertEqual(
(values, dump(val_array)),
([1, 2, 3, 4], "01-02-03-04")
)
def test_2(self):
val_array = (Value * 4)()
# memory contains 4 NUL bytes now, that's correct
self.assertEqual("00-00-00-00", dump(val_array))
ptr = cast(val_array, POINTER(Value))
# set the values of the array through the pointer:
for i in range(4):
ptr[i].val = i + 1
values = [ptr[i].val for i in range(4)]
# These are the expected results: here s the bug!
self.assertEqual(
(values, dump(val_array)),
([1, 2, 3, 4], "01-02-03-04")
)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,134 @@
import unittest
from ctypes import *
formats = "bBhHiIlLqQfd"
formats = c_byte, c_ubyte, c_short, c_ushort, c_int, c_uint, \
c_long, c_ulonglong, c_float, c_double, c_longdouble
class ArrayTestCase(unittest.TestCase):
def test_simple(self):
# create classes holding simple numeric types, and check
# various properties.
init = range(15, 25)
for fmt in formats:
alen = len(init)
int_array = ARRAY(fmt, alen)
ia = int_array(*init)
# length of instance ok?
self.assertEqual(len(ia), alen)
# slot values ok?
values = [ia[i] for i in range(len(init))]
self.assertEqual(values, init)
# change the items
from operator import setitem
new_values = range(42, 42+alen)
[setitem(ia, n, new_values[n]) for n in range(alen)]
values = [ia[i] for i in range(len(init))]
self.assertEqual(values, new_values)
# are the items initialized to 0?
ia = int_array()
values = [ia[i] for i in range(len(init))]
self.assertEqual(values, [0] * len(init))
# Too many in itializers should be caught
self.assertRaises(IndexError, int_array, *range(alen*2))
CharArray = ARRAY(c_char, 3)
ca = CharArray("a", "b", "c")
# Should this work? It doesn't:
# CharArray("abc")
self.assertRaises(TypeError, CharArray, "abc")
self.assertEqual(ca[0], "a")
self.assertEqual(ca[1], "b")
self.assertEqual(ca[2], "c")
self.assertEqual(ca[-3], "a")
self.assertEqual(ca[-2], "b")
self.assertEqual(ca[-1], "c")
self.assertEqual(len(ca), 3)
# slicing is now supported, but not extended slicing (3-argument)!
from operator import getslice, delitem
self.assertRaises(TypeError, getslice, ca, 0, 1, -1)
# cannot delete items
self.assertRaises(TypeError, delitem, ca, 0)
def test_numeric_arrays(self):
alen = 5
numarray = ARRAY(c_int, alen)
na = numarray()
values = [na[i] for i in range(alen)]
self.assertEqual(values, [0] * alen)
na = numarray(*[c_int()] * alen)
values = [na[i] for i in range(alen)]
self.assertEqual(values, [0]*alen)
na = numarray(1, 2, 3, 4, 5)
values = [i for i in na]
self.assertEqual(values, [1, 2, 3, 4, 5])
na = numarray(*map(c_int, (1, 2, 3, 4, 5)))
values = [i for i in na]
self.assertEqual(values, [1, 2, 3, 4, 5])
def test_classcache(self):
self.assertTrue(not ARRAY(c_int, 3) is ARRAY(c_int, 4))
self.assertTrue(ARRAY(c_int, 3) is ARRAY(c_int, 3))
def test_from_address(self):
# Failed with 0.9.8, reported by JUrner
p = create_string_buffer("foo")
sz = (c_char * 3).from_address(addressof(p))
self.assertEqual(sz[:], "foo")
self.assertEqual(sz[::], "foo")
self.assertEqual(sz[::-1], "oof")
self.assertEqual(sz[::3], "f")
self.assertEqual(sz[1:4:2], "o")
self.assertEqual(sz.value, "foo")
try:
create_unicode_buffer
except NameError:
pass
else:
def test_from_addressW(self):
p = create_unicode_buffer("foo")
sz = (c_wchar * 3).from_address(addressof(p))
self.assertEqual(sz[:], "foo")
self.assertEqual(sz[::], "foo")
self.assertEqual(sz[::-1], "oof")
self.assertEqual(sz[::3], "f")
self.assertEqual(sz[1:4:2], "o")
self.assertEqual(sz.value, "foo")
def test_cache(self):
# Array types are cached internally in the _ctypes extension,
# in a WeakValueDictionary. Make sure the array type is
# removed from the cache when the itemtype goes away. This
# test will not fail, but will show a leak in the testsuite.
# Create a new type:
class my_int(c_int):
pass
# Create a new array type based on it:
t1 = my_int * 1
t2 = my_int * 1
self.assertTrue(t1 is t2)
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,215 @@
import unittest
from ctypes import *
import _ctypes_test
dll = CDLL(_ctypes_test.__file__)
try:
CALLBACK_FUNCTYPE = WINFUNCTYPE
except NameError:
# fake to enable this test on Linux
CALLBACK_FUNCTYPE = CFUNCTYPE
class POINT(Structure):
_fields_ = [("x", c_int), ("y", c_int)]
class BasicWrapTestCase(unittest.TestCase):
def wrap(self, param):
return param
def test_wchar_parm(self):
try:
c_wchar
except NameError:
return
f = dll._testfunc_i_bhilfd
f.argtypes = [c_byte, c_wchar, c_int, c_long, c_float, c_double]
result = f(self.wrap(1), self.wrap(u"x"), self.wrap(3), self.wrap(4), self.wrap(5.0), self.wrap(6.0))
self.assertEqual(result, 139)
self.assertTrue(type(result), int)
def test_pointers(self):
f = dll._testfunc_p_p
f.restype = POINTER(c_int)
f.argtypes = [POINTER(c_int)]
# This only works if the value c_int(42) passed to the
# function is still alive while the pointer (the result) is
# used.
v = c_int(42)
self.assertEqual(pointer(v).contents.value, 42)
result = f(self.wrap(pointer(v)))
self.assertEqual(type(result), POINTER(c_int))
self.assertEqual(result.contents.value, 42)
# This on works...
result = f(self.wrap(pointer(v)))
self.assertEqual(result.contents.value, v.value)
p = pointer(c_int(99))
result = f(self.wrap(p))
self.assertEqual(result.contents.value, 99)
def test_shorts(self):
f = dll._testfunc_callback_i_if
args = []
expected = [262144, 131072, 65536, 32768, 16384, 8192, 4096, 2048,
1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1]
def callback(v):
args.append(v)
return v
CallBack = CFUNCTYPE(c_int, c_int)
cb = CallBack(callback)
f(self.wrap(2**18), self.wrap(cb))
self.assertEqual(args, expected)
################################################################
def test_callbacks(self):
f = dll._testfunc_callback_i_if
f.restype = c_int
MyCallback = CFUNCTYPE(c_int, c_int)
def callback(value):
#print "called back with", value
return value
cb = MyCallback(callback)
result = f(self.wrap(-10), self.wrap(cb))
self.assertEqual(result, -18)
# test with prototype
f.argtypes = [c_int, MyCallback]
cb = MyCallback(callback)
result = f(self.wrap(-10), self.wrap(cb))
self.assertEqual(result, -18)
result = f(self.wrap(-10), self.wrap(cb))
self.assertEqual(result, -18)
AnotherCallback = CALLBACK_FUNCTYPE(c_int, c_int, c_int, c_int, c_int)
# check that the prototype works: we call f with wrong
# argument types
cb = AnotherCallback(callback)
self.assertRaises(ArgumentError, f, self.wrap(-10), self.wrap(cb))
def test_callbacks_2(self):
# Can also use simple datatypes as argument type specifiers
# for the callback function.
# In this case the call receives an instance of that type
f = dll._testfunc_callback_i_if
f.restype = c_int
MyCallback = CFUNCTYPE(c_int, c_int)
f.argtypes = [c_int, MyCallback]
def callback(value):
#print "called back with", value
self.assertEqual(type(value), int)
return value
cb = MyCallback(callback)
result = f(self.wrap(-10), self.wrap(cb))
self.assertEqual(result, -18)
def test_longlong_callbacks(self):
f = dll._testfunc_callback_q_qf
f.restype = c_longlong
MyCallback = CFUNCTYPE(c_longlong, c_longlong)
f.argtypes = [c_longlong, MyCallback]
def callback(value):
self.assertTrue(isinstance(value, (int, long)))
return value & 0x7FFFFFFF
cb = MyCallback(callback)
self.assertEqual(13577625587, int(f(self.wrap(1000000000000), self.wrap(cb))))
def test_byval(self):
# without prototype
ptin = POINT(1, 2)
ptout = POINT()
# EXPORT int _testfunc_byval(point in, point *pout)
result = dll._testfunc_byval(ptin, byref(ptout))
got = result, ptout.x, ptout.y
expected = 3, 1, 2
self.assertEqual(got, expected)
# with prototype
ptin = POINT(101, 102)
ptout = POINT()
dll._testfunc_byval.argtypes = (POINT, POINTER(POINT))
dll._testfunc_byval.restype = c_int
result = dll._testfunc_byval(self.wrap(ptin), byref(ptout))
got = result, ptout.x, ptout.y
expected = 203, 101, 102
self.assertEqual(got, expected)
def test_struct_return_2H(self):
class S2H(Structure):
_fields_ = [("x", c_short),
("y", c_short)]
dll.ret_2h_func.restype = S2H
dll.ret_2h_func.argtypes = [S2H]
inp = S2H(99, 88)
s2h = dll.ret_2h_func(self.wrap(inp))
self.assertEqual((s2h.x, s2h.y), (99*2, 88*3))
def test_struct_return_8H(self):
class S8I(Structure):
_fields_ = [("a", c_int),
("b", c_int),
("c", c_int),
("d", c_int),
("e", c_int),
("f", c_int),
("g", c_int),
("h", c_int)]
dll.ret_8i_func.restype = S8I
dll.ret_8i_func.argtypes = [S8I]
inp = S8I(9, 8, 7, 6, 5, 4, 3, 2)
s8i = dll.ret_8i_func(self.wrap(inp))
self.assertEqual((s8i.a, s8i.b, s8i.c, s8i.d, s8i.e, s8i.f, s8i.g, s8i.h),
(9*2, 8*3, 7*4, 6*5, 5*6, 4*7, 3*8, 2*9))
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class AsParamWrapper(object):
def __init__(self, param):
self._as_parameter_ = param
class AsParamWrapperTestCase(BasicWrapTestCase):
wrap = AsParamWrapper
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class AsParamPropertyWrapper(object):
def __init__(self, param):
self._param = param
def getParameter(self):
return self._param
_as_parameter_ = property(getParameter)
class AsParamPropertyWrapperTestCase(BasicWrapTestCase):
wrap = AsParamPropertyWrapper
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,244 @@
from ctypes import *
import unittest
import os
import ctypes
import _ctypes_test
class BITS(Structure):
_fields_ = [("A", c_int, 1),
("B", c_int, 2),
("C", c_int, 3),
("D", c_int, 4),
("E", c_int, 5),
("F", c_int, 6),
("G", c_int, 7),
("H", c_int, 8),
("I", c_int, 9),
("M", c_short, 1),
("N", c_short, 2),
("O", c_short, 3),
("P", c_short, 4),
("Q", c_short, 5),
("R", c_short, 6),
("S", c_short, 7)]
func = CDLL(_ctypes_test.__file__).unpack_bitfields
func.argtypes = POINTER(BITS), c_char
##for n in "ABCDEFGHIMNOPQRS":
## print n, hex(getattr(BITS, n).size), getattr(BITS, n).offset
class C_Test(unittest.TestCase):
def test_ints(self):
for i in range(512):
for name in "ABCDEFGHI":
b = BITS()
setattr(b, name, i)
self.assertEqual((name, i, getattr(b, name)), (name, i, func(byref(b), name)))
def test_shorts(self):
for i in range(256):
for name in "MNOPQRS":
b = BITS()
setattr(b, name, i)
self.assertEqual((name, i, getattr(b, name)), (name, i, func(byref(b), name)))
signed_int_types = (c_byte, c_short, c_int, c_long, c_longlong)
unsigned_int_types = (c_ubyte, c_ushort, c_uint, c_ulong, c_ulonglong)
int_types = unsigned_int_types + signed_int_types
class BitFieldTest(unittest.TestCase):
def test_longlong(self):
class X(Structure):
_fields_ = [("a", c_longlong, 1),
("b", c_longlong, 62),
("c", c_longlong, 1)]
self.assertEqual(sizeof(X), sizeof(c_longlong))
x = X()
x.a, x.b, x.c = -1, 7, -1
self.assertEqual((x.a, x.b, x.c), (-1, 7, -1))
def test_ulonglong(self):
class X(Structure):
_fields_ = [("a", c_ulonglong, 1),
("b", c_ulonglong, 62),
("c", c_ulonglong, 1)]
self.assertEqual(sizeof(X), sizeof(c_longlong))
x = X()
self.assertEqual((x.a, x.b, x.c), (0, 0, 0))
x.a, x.b, x.c = 7, 7, 7
self.assertEqual((x.a, x.b, x.c), (1, 7, 1))
def test_signed(self):
for c_typ in signed_int_types:
class X(Structure):
_fields_ = [("dummy", c_typ),
("a", c_typ, 3),
("b", c_typ, 3),
("c", c_typ, 1)]
self.assertEqual(sizeof(X), sizeof(c_typ)*2)
x = X()
self.assertEqual((c_typ, x.a, x.b, x.c), (c_typ, 0, 0, 0))
x.a = -1
self.assertEqual((c_typ, x.a, x.b, x.c), (c_typ, -1, 0, 0))
x.a, x.b = 0, -1
self.assertEqual((c_typ, x.a, x.b, x.c), (c_typ, 0, -1, 0))
def test_unsigned(self):
for c_typ in unsigned_int_types:
class X(Structure):
_fields_ = [("a", c_typ, 3),
("b", c_typ, 3),
("c", c_typ, 1)]
self.assertEqual(sizeof(X), sizeof(c_typ))
x = X()
self.assertEqual((c_typ, x.a, x.b, x.c), (c_typ, 0, 0, 0))
x.a = -1
self.assertEqual((c_typ, x.a, x.b, x.c), (c_typ, 7, 0, 0))
x.a, x.b = 0, -1
self.assertEqual((c_typ, x.a, x.b, x.c), (c_typ, 0, 7, 0))
def fail_fields(self, *fields):
return self.get_except(type(Structure), "X", (),
{"_fields_": fields})
def test_nonint_types(self):
# bit fields are not allowed on non-integer types.
result = self.fail_fields(("a", c_char_p, 1))
self.assertEqual(result, (TypeError, 'bit fields not allowed for type c_char_p'))
result = self.fail_fields(("a", c_void_p, 1))
self.assertEqual(result, (TypeError, 'bit fields not allowed for type c_void_p'))
if c_int != c_long:
result = self.fail_fields(("a", POINTER(c_int), 1))
self.assertEqual(result, (TypeError, 'bit fields not allowed for type LP_c_int'))
result = self.fail_fields(("a", c_char, 1))
self.assertEqual(result, (TypeError, 'bit fields not allowed for type c_char'))
try:
c_wchar
except NameError:
pass
else:
result = self.fail_fields(("a", c_wchar, 1))
self.assertEqual(result, (TypeError, 'bit fields not allowed for type c_wchar'))
class Dummy(Structure):
_fields_ = []
result = self.fail_fields(("a", Dummy, 1))
self.assertEqual(result, (TypeError, 'bit fields not allowed for type Dummy'))
def test_single_bitfield_size(self):
for c_typ in int_types:
result = self.fail_fields(("a", c_typ, -1))
self.assertEqual(result, (ValueError, 'number of bits invalid for bit field'))
result = self.fail_fields(("a", c_typ, 0))
self.assertEqual(result, (ValueError, 'number of bits invalid for bit field'))
class X(Structure):
_fields_ = [("a", c_typ, 1)]
self.assertEqual(sizeof(X), sizeof(c_typ))
class X(Structure):
_fields_ = [("a", c_typ, sizeof(c_typ)*8)]
self.assertEqual(sizeof(X), sizeof(c_typ))
result = self.fail_fields(("a", c_typ, sizeof(c_typ)*8 + 1))
self.assertEqual(result, (ValueError, 'number of bits invalid for bit field'))
def test_multi_bitfields_size(self):
class X(Structure):
_fields_ = [("a", c_short, 1),
("b", c_short, 14),
("c", c_short, 1)]
self.assertEqual(sizeof(X), sizeof(c_short))
class X(Structure):
_fields_ = [("a", c_short, 1),
("a1", c_short),
("b", c_short, 14),
("c", c_short, 1)]
self.assertEqual(sizeof(X), sizeof(c_short)*3)
self.assertEqual(X.a.offset, 0)
self.assertEqual(X.a1.offset, sizeof(c_short))
self.assertEqual(X.b.offset, sizeof(c_short)*2)
self.assertEqual(X.c.offset, sizeof(c_short)*2)
class X(Structure):
_fields_ = [("a", c_short, 3),
("b", c_short, 14),
("c", c_short, 14)]
self.assertEqual(sizeof(X), sizeof(c_short)*3)
self.assertEqual(X.a.offset, sizeof(c_short)*0)
self.assertEqual(X.b.offset, sizeof(c_short)*1)
self.assertEqual(X.c.offset, sizeof(c_short)*2)
def get_except(self, func, *args, **kw):
try:
func(*args, **kw)
except Exception, detail:
return detail.__class__, str(detail)
def test_mixed_1(self):
class X(Structure):
_fields_ = [("a", c_byte, 4),
("b", c_int, 4)]
if os.name in ("nt", "ce"):
self.assertEqual(sizeof(X), sizeof(c_int)*2)
else:
self.assertEqual(sizeof(X), sizeof(c_int))
def test_mixed_2(self):
class X(Structure):
_fields_ = [("a", c_byte, 4),
("b", c_int, 32)]
self.assertEqual(sizeof(X), sizeof(c_int)*2)
def test_mixed_3(self):
class X(Structure):
_fields_ = [("a", c_byte, 4),
("b", c_ubyte, 4)]
self.assertEqual(sizeof(X), sizeof(c_byte))
def test_mixed_4(self):
class X(Structure):
_fields_ = [("a", c_short, 4),
("b", c_short, 4),
("c", c_int, 24),
("d", c_short, 4),
("e", c_short, 4),
("f", c_int, 24)]
# MSVC does NOT combine c_short and c_int into one field, GCC
# does (unless GCC is run with '-mms-bitfields' which
# produces code compatible with MSVC).
if os.name in ("nt", "ce"):
self.assertEqual(sizeof(X), sizeof(c_int) * 4)
else:
self.assertEqual(sizeof(X), sizeof(c_int) * 2)
def test_anon_bitfields(self):
# anonymous bit-fields gave a strange error message
class X(Structure):
_fields_ = [("a", c_byte, 4),
("b", c_ubyte, 4)]
class Y(Structure):
_anonymous_ = ["_"]
_fields_ = [("_", X)]
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,70 @@
from ctypes import *
import unittest
class StringBufferTestCase(unittest.TestCase):
def test_buffer(self):
b = create_string_buffer(32)
self.assertEqual(len(b), 32)
self.assertEqual(sizeof(b), 32 * sizeof(c_char))
self.assertTrue(type(b[0]) is str)
b = create_string_buffer("abc")
self.assertEqual(len(b), 4) # trailing nul char
self.assertEqual(sizeof(b), 4 * sizeof(c_char))
self.assertTrue(type(b[0]) is str)
self.assertEqual(b[0], "a")
self.assertEqual(b[:], "abc\0")
self.assertEqual(b[::], "abc\0")
self.assertEqual(b[::-1], "\0cba")
self.assertEqual(b[::2], "ac")
self.assertEqual(b[::5], "a")
def test_string_conversion(self):
b = create_string_buffer(u"abc")
self.assertEqual(len(b), 4) # trailing nul char
self.assertEqual(sizeof(b), 4 * sizeof(c_char))
self.assertTrue(type(b[0]) is str)
self.assertEqual(b[0], "a")
self.assertEqual(b[:], "abc\0")
self.assertEqual(b[::], "abc\0")
self.assertEqual(b[::-1], "\0cba")
self.assertEqual(b[::2], "ac")
self.assertEqual(b[::5], "a")
try:
c_wchar
except NameError:
pass
else:
def test_unicode_buffer(self):
b = create_unicode_buffer(32)
self.assertEqual(len(b), 32)
self.assertEqual(sizeof(b), 32 * sizeof(c_wchar))
self.assertTrue(type(b[0]) is unicode)
b = create_unicode_buffer(u"abc")
self.assertEqual(len(b), 4) # trailing nul char
self.assertEqual(sizeof(b), 4 * sizeof(c_wchar))
self.assertTrue(type(b[0]) is unicode)
self.assertEqual(b[0], u"a")
self.assertEqual(b[:], "abc\0")
self.assertEqual(b[::], "abc\0")
self.assertEqual(b[::-1], "\0cba")
self.assertEqual(b[::2], "ac")
self.assertEqual(b[::5], "a")
def test_unicode_conversion(self):
b = create_unicode_buffer("abc")
self.assertEqual(len(b), 4) # trailing nul char
self.assertEqual(sizeof(b), 4 * sizeof(c_wchar))
self.assertTrue(type(b[0]) is unicode)
self.assertEqual(b[0], u"a")
self.assertEqual(b[:], "abc\0")
self.assertEqual(b[::], "abc\0")
self.assertEqual(b[::-1], "\0cba")
self.assertEqual(b[::2], "ac")
self.assertEqual(b[::5], "a")
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,280 @@
import sys, unittest, struct, math
from binascii import hexlify
from ctypes import *
def bin(s):
return hexlify(memoryview(s)).upper()
# Each *simple* type that supports different byte orders has an
# __ctype_be__ attribute that specifies the same type in BIG ENDIAN
# byte order, and a __ctype_le__ attribute that is the same type in
# LITTLE ENDIAN byte order.
#
# For Structures and Unions, these types are created on demand.
class Test(unittest.TestCase):
def X_test(self):
print >> sys.stderr, sys.byteorder
for i in range(32):
bits = BITS()
setattr(bits, "i%s" % i, 1)
dump(bits)
def test_endian_short(self):
if sys.byteorder == "little":
self.assertTrue(c_short.__ctype_le__ is c_short)
self.assertTrue(c_short.__ctype_be__.__ctype_le__ is c_short)
else:
self.assertTrue(c_short.__ctype_be__ is c_short)
self.assertTrue(c_short.__ctype_le__.__ctype_be__ is c_short)
s = c_short.__ctype_be__(0x1234)
self.assertEqual(bin(struct.pack(">h", 0x1234)), "1234")
self.assertEqual(bin(s), "1234")
self.assertEqual(s.value, 0x1234)
s = c_short.__ctype_le__(0x1234)
self.assertEqual(bin(struct.pack("<h", 0x1234)), "3412")
self.assertEqual(bin(s), "3412")
self.assertEqual(s.value, 0x1234)
s = c_ushort.__ctype_be__(0x1234)
self.assertEqual(bin(struct.pack(">h", 0x1234)), "1234")
self.assertEqual(bin(s), "1234")
self.assertEqual(s.value, 0x1234)
s = c_ushort.__ctype_le__(0x1234)
self.assertEqual(bin(struct.pack("<h", 0x1234)), "3412")
self.assertEqual(bin(s), "3412")
self.assertEqual(s.value, 0x1234)
def test_endian_int(self):
if sys.byteorder == "little":
self.assertTrue(c_int.__ctype_le__ is c_int)
self.assertTrue(c_int.__ctype_be__.__ctype_le__ is c_int)
else:
self.assertTrue(c_int.__ctype_be__ is c_int)
self.assertTrue(c_int.__ctype_le__.__ctype_be__ is c_int)
s = c_int.__ctype_be__(0x12345678)
self.assertEqual(bin(struct.pack(">i", 0x12345678)), "12345678")
self.assertEqual(bin(s), "12345678")
self.assertEqual(s.value, 0x12345678)
s = c_int.__ctype_le__(0x12345678)
self.assertEqual(bin(struct.pack("<i", 0x12345678)), "78563412")
self.assertEqual(bin(s), "78563412")
self.assertEqual(s.value, 0x12345678)
s = c_uint.__ctype_be__(0x12345678)
self.assertEqual(bin(struct.pack(">I", 0x12345678)), "12345678")
self.assertEqual(bin(s), "12345678")
self.assertEqual(s.value, 0x12345678)
s = c_uint.__ctype_le__(0x12345678)
self.assertEqual(bin(struct.pack("<I", 0x12345678)), "78563412")
self.assertEqual(bin(s), "78563412")
self.assertEqual(s.value, 0x12345678)
def test_endian_longlong(self):
if sys.byteorder == "little":
self.assertTrue(c_longlong.__ctype_le__ is c_longlong)
self.assertTrue(c_longlong.__ctype_be__.__ctype_le__ is c_longlong)
else:
self.assertTrue(c_longlong.__ctype_be__ is c_longlong)
self.assertTrue(c_longlong.__ctype_le__.__ctype_be__ is c_longlong)
s = c_longlong.__ctype_be__(0x1234567890ABCDEF)
self.assertEqual(bin(struct.pack(">q", 0x1234567890ABCDEF)), "1234567890ABCDEF")
self.assertEqual(bin(s), "1234567890ABCDEF")
self.assertEqual(s.value, 0x1234567890ABCDEF)
s = c_longlong.__ctype_le__(0x1234567890ABCDEF)
self.assertEqual(bin(struct.pack("<q", 0x1234567890ABCDEF)), "EFCDAB9078563412")
self.assertEqual(bin(s), "EFCDAB9078563412")
self.assertEqual(s.value, 0x1234567890ABCDEF)
s = c_ulonglong.__ctype_be__(0x1234567890ABCDEF)
self.assertEqual(bin(struct.pack(">Q", 0x1234567890ABCDEF)), "1234567890ABCDEF")
self.assertEqual(bin(s), "1234567890ABCDEF")
self.assertEqual(s.value, 0x1234567890ABCDEF)
s = c_ulonglong.__ctype_le__(0x1234567890ABCDEF)
self.assertEqual(bin(struct.pack("<Q", 0x1234567890ABCDEF)), "EFCDAB9078563412")
self.assertEqual(bin(s), "EFCDAB9078563412")
self.assertEqual(s.value, 0x1234567890ABCDEF)
def test_endian_float(self):
if sys.byteorder == "little":
self.assertTrue(c_float.__ctype_le__ is c_float)
self.assertTrue(c_float.__ctype_be__.__ctype_le__ is c_float)
else:
self.assertTrue(c_float.__ctype_be__ is c_float)
self.assertTrue(c_float.__ctype_le__.__ctype_be__ is c_float)
s = c_float(math.pi)
self.assertEqual(bin(struct.pack("f", math.pi)), bin(s))
# Hm, what's the precision of a float compared to a double?
self.assertAlmostEqual(s.value, math.pi, 6)
s = c_float.__ctype_le__(math.pi)
self.assertAlmostEqual(s.value, math.pi, 6)
self.assertEqual(bin(struct.pack("<f", math.pi)), bin(s))
s = c_float.__ctype_be__(math.pi)
self.assertAlmostEqual(s.value, math.pi, 6)
self.assertEqual(bin(struct.pack(">f", math.pi)), bin(s))
def test_endian_double(self):
if sys.byteorder == "little":
self.assertTrue(c_double.__ctype_le__ is c_double)
self.assertTrue(c_double.__ctype_be__.__ctype_le__ is c_double)
else:
self.assertTrue(c_double.__ctype_be__ is c_double)
self.assertTrue(c_double.__ctype_le__.__ctype_be__ is c_double)
s = c_double(math.pi)
self.assertEqual(s.value, math.pi)
self.assertEqual(bin(struct.pack("d", math.pi)), bin(s))
s = c_double.__ctype_le__(math.pi)
self.assertEqual(s.value, math.pi)
self.assertEqual(bin(struct.pack("<d", math.pi)), bin(s))
s = c_double.__ctype_be__(math.pi)
self.assertEqual(s.value, math.pi)
self.assertEqual(bin(struct.pack(">d", math.pi)), bin(s))
def test_endian_other(self):
self.assertTrue(c_byte.__ctype_le__ is c_byte)
self.assertTrue(c_byte.__ctype_be__ is c_byte)
self.assertTrue(c_ubyte.__ctype_le__ is c_ubyte)
self.assertTrue(c_ubyte.__ctype_be__ is c_ubyte)
self.assertTrue(c_char.__ctype_le__ is c_char)
self.assertTrue(c_char.__ctype_be__ is c_char)
def test_struct_fields_1(self):
if sys.byteorder == "little":
base = BigEndianStructure
else:
base = LittleEndianStructure
class T(base):
pass
_fields_ = [("a", c_ubyte),
("b", c_byte),
("c", c_short),
("d", c_ushort),
("e", c_int),
("f", c_uint),
("g", c_long),
("h", c_ulong),
("i", c_longlong),
("k", c_ulonglong),
("l", c_float),
("m", c_double),
("n", c_char),
("b1", c_byte, 3),
("b2", c_byte, 3),
("b3", c_byte, 2),
("a", c_int * 3 * 3 * 3)]
T._fields_ = _fields_
# these fields do not support different byte order:
for typ in c_wchar, c_void_p, POINTER(c_int):
_fields_.append(("x", typ))
class T(base):
pass
self.assertRaises(TypeError, setattr, T, "_fields_", [("x", typ)])
def test_struct_struct(self):
# Nested structures with different byte order not (yet) supported
if sys.byteorder == "little":
base = BigEndianStructure
else:
base = LittleEndianStructure
class T(Structure):
_fields_ = [("a", c_int),
("b", c_int)]
class S(base):
pass
self.assertRaises(TypeError, setattr, S, "_fields_", [("s", T)])
def test_struct_fields_2(self):
# standard packing in struct uses no alignment.
# So, we have to align using pad bytes.
#
# Unaligned accesses will crash Python (on those platforms that
# don't allow it, like sparc solaris).
if sys.byteorder == "little":
base = BigEndianStructure
fmt = ">bxhid"
else:
base = LittleEndianStructure
fmt = "<bxhid"
class S(base):
_fields_ = [("b", c_byte),
("h", c_short),
("i", c_int),
("d", c_double)]
s1 = S(0x12, 0x1234, 0x12345678, 3.14)
s2 = struct.pack(fmt, 0x12, 0x1234, 0x12345678, 3.14)
self.assertEqual(bin(s1), bin(s2))
def test_unaligned_nonnative_struct_fields(self):
if sys.byteorder == "little":
base = BigEndianStructure
fmt = ">b h xi xd"
else:
base = LittleEndianStructure
fmt = "<b h xi xd"
class S(base):
_pack_ = 1
_fields_ = [("b", c_byte),
("h", c_short),
("_1", c_byte),
("i", c_int),
("_2", c_byte),
("d", c_double)]
s1 = S()
s1.b = 0x12
s1.h = 0x1234
s1.i = 0x12345678
s1.d = 3.14
s2 = struct.pack(fmt, 0x12, 0x1234, 0x12345678, 3.14)
self.assertEqual(bin(s1), bin(s2))
def test_unaligned_native_struct_fields(self):
if sys.byteorder == "little":
fmt = "<b h xi xd"
else:
base = LittleEndianStructure
fmt = ">b h xi xd"
class S(Structure):
_pack_ = 1
_fields_ = [("b", c_byte),
("h", c_short),
("_1", c_byte),
("i", c_int),
("_2", c_byte),
("d", c_double)]
s1 = S()
s1.b = 0x12
s1.h = 0x1234
s1.i = 0x12345678
s1.d = 3.14
s2 = struct.pack(fmt, 0x12, 0x1234, 0x12345678, 3.14)
self.assertEqual(bin(s1), bin(s2))
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,212 @@
import unittest
from ctypes import *
import _ctypes_test
class Callbacks(unittest.TestCase):
functype = CFUNCTYPE
## def tearDown(self):
## import gc
## gc.collect()
def callback(self, *args):
self.got_args = args
return args[-1]
def check_type(self, typ, arg):
PROTO = self.functype.im_func(typ, typ)
result = PROTO(self.callback)(arg)
if typ == c_float:
self.assertAlmostEqual(result, arg, places=5)
else:
self.assertEqual(self.got_args, (arg,))
self.assertEqual(result, arg)
PROTO = self.functype.im_func(typ, c_byte, typ)
result = PROTO(self.callback)(-3, arg)
if typ == c_float:
self.assertAlmostEqual(result, arg, places=5)
else:
self.assertEqual(self.got_args, (-3, arg))
self.assertEqual(result, arg)
################
def test_byte(self):
self.check_type(c_byte, 42)
self.check_type(c_byte, -42)
def test_ubyte(self):
self.check_type(c_ubyte, 42)
def test_short(self):
self.check_type(c_short, 42)
self.check_type(c_short, -42)
def test_ushort(self):
self.check_type(c_ushort, 42)
def test_int(self):
self.check_type(c_int, 42)
self.check_type(c_int, -42)
def test_uint(self):
self.check_type(c_uint, 42)
def test_long(self):
self.check_type(c_long, 42)
self.check_type(c_long, -42)
def test_ulong(self):
self.check_type(c_ulong, 42)
def test_longlong(self):
# test some 64-bit values, positive and negative
self.check_type(c_longlong, 5948291757245277467)
self.check_type(c_longlong, -5229388909784190580)
self.check_type(c_longlong, 42)
self.check_type(c_longlong, -42)
def test_ulonglong(self):
# test some 64-bit values, with and without msb set.
self.check_type(c_ulonglong, 10955412242170339782)
self.check_type(c_ulonglong, 3665885499841167458)
self.check_type(c_ulonglong, 42)
def test_float(self):
# only almost equal: double -> float -> double
import math
self.check_type(c_float, math.e)
self.check_type(c_float, -math.e)
def test_double(self):
self.check_type(c_double, 3.14)
self.check_type(c_double, -3.14)
def test_longdouble(self):
self.check_type(c_longdouble, 3.14)
self.check_type(c_longdouble, -3.14)
def test_char(self):
self.check_type(c_char, "x")
self.check_type(c_char, "a")
# disabled: would now (correctly) raise a RuntimeWarning about
# a memory leak. A callback function cannot return a non-integral
# C type without causing a memory leak.
## def test_char_p(self):
## self.check_type(c_char_p, "abc")
## self.check_type(c_char_p, "def")
def test_pyobject(self):
o = ()
from sys import getrefcount as grc
for o in (), [], object():
initial = grc(o)
# This call leaks a reference to 'o'...
self.check_type(py_object, o)
before = grc(o)
# ...but this call doesn't leak any more. Where is the refcount?
self.check_type(py_object, o)
after = grc(o)
self.assertEqual((after, o), (before, o))
def test_unsupported_restype_1(self):
# Only "fundamental" result types are supported for callback
# functions, the type must have a non-NULL stgdict->setfunc.
# POINTER(c_double), for example, is not supported.
prototype = self.functype.im_func(POINTER(c_double))
# The type is checked when the prototype is called
self.assertRaises(TypeError, prototype, lambda: None)
def test_unsupported_restype_2(self):
prototype = self.functype.im_func(object)
self.assertRaises(TypeError, prototype, lambda: None)
def test_issue_7959(self):
proto = self.functype.im_func(None)
class X(object):
def func(self): pass
def __init__(self):
self.v = proto(self.func)
import gc
for i in range(32):
X()
gc.collect()
live = [x for x in gc.get_objects()
if isinstance(x, X)]
self.assertEqual(len(live), 0)
try:
WINFUNCTYPE
except NameError:
pass
else:
class StdcallCallbacks(Callbacks):
functype = WINFUNCTYPE
################################################################
class SampleCallbacksTestCase(unittest.TestCase):
def test_integrate(self):
# Derived from some then non-working code, posted by David Foster
dll = CDLL(_ctypes_test.__file__)
# The function prototype called by 'integrate': double func(double);
CALLBACK = CFUNCTYPE(c_double, c_double)
# The integrate function itself, exposed from the _ctypes_test dll
integrate = dll.integrate
integrate.argtypes = (c_double, c_double, CALLBACK, c_long)
integrate.restype = c_double
def func(x):
return x**2
result = integrate(0.0, 1.0, CALLBACK(func), 10)
diff = abs(result - 1./3.)
self.assertLess(diff, 0.01, "%s not less than 0.01" % diff)
def test_issue_8959_a(self):
from ctypes.util import find_library
libc_path = find_library("c")
if not libc_path:
return # cannot test
libc = CDLL(libc_path)
@CFUNCTYPE(c_int, POINTER(c_int), POINTER(c_int))
def cmp_func(a, b):
return a[0] - b[0]
array = (c_int * 5)(5, 1, 99, 7, 33)
libc.qsort(array, len(array), sizeof(c_int), cmp_func)
self.assertEqual(array[:], [1, 5, 7, 33, 99])
try:
WINFUNCTYPE
except NameError:
pass
else:
def test_issue_8959_b(self):
from ctypes.wintypes import BOOL, HWND, LPARAM
global windowCount
windowCount = 0
@WINFUNCTYPE(BOOL, HWND, LPARAM)
def EnumWindowsCallbackFunc(hwnd, lParam):
global windowCount
windowCount += 1
return True #Allow windows to keep enumerating
windll.user32.EnumWindows(EnumWindowsCallbackFunc, 0)
################################################################
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,89 @@
from ctypes import *
import unittest
import sys
class Test(unittest.TestCase):
def test_array2pointer(self):
array = (c_int * 3)(42, 17, 2)
# casting an array to a pointer works.
ptr = cast(array, POINTER(c_int))
self.assertEqual([ptr[i] for i in range(3)], [42, 17, 2])
if 2*sizeof(c_short) == sizeof(c_int):
ptr = cast(array, POINTER(c_short))
if sys.byteorder == "little":
self.assertEqual([ptr[i] for i in range(6)],
[42, 0, 17, 0, 2, 0])
else:
self.assertEqual([ptr[i] for i in range(6)],
[0, 42, 0, 17, 0, 2])
def test_address2pointer(self):
array = (c_int * 3)(42, 17, 2)
address = addressof(array)
ptr = cast(c_void_p(address), POINTER(c_int))
self.assertEqual([ptr[i] for i in range(3)], [42, 17, 2])
ptr = cast(address, POINTER(c_int))
self.assertEqual([ptr[i] for i in range(3)], [42, 17, 2])
def test_p2a_objects(self):
array = (c_char_p * 5)()
self.assertEqual(array._objects, None)
array[0] = "foo bar"
self.assertEqual(array._objects, {'0': "foo bar"})
p = cast(array, POINTER(c_char_p))
# array and p share a common _objects attribute
self.assertTrue(p._objects is array._objects)
self.assertEqual(array._objects, {'0': "foo bar", id(array): array})
p[0] = "spam spam"
self.assertEqual(p._objects, {'0': "spam spam", id(array): array})
self.assertTrue(array._objects is p._objects)
p[1] = "foo bar"
self.assertEqual(p._objects, {'1': 'foo bar', '0': "spam spam", id(array): array})
self.assertTrue(array._objects is p._objects)
def test_other(self):
p = cast((c_int * 4)(1, 2, 3, 4), POINTER(c_int))
self.assertEqual(p[:4], [1,2, 3, 4])
self.assertEqual(p[:4:], [1, 2, 3, 4])
self.assertEqual(p[3:-1:-1], [4, 3, 2, 1])
self.assertEqual(p[:4:3], [1, 4])
c_int()
self.assertEqual(p[:4], [1, 2, 3, 4])
self.assertEqual(p[:4:], [1, 2, 3, 4])
self.assertEqual(p[3:-1:-1], [4, 3, 2, 1])
self.assertEqual(p[:4:3], [1, 4])
p[2] = 96
self.assertEqual(p[:4], [1, 2, 96, 4])
self.assertEqual(p[:4:], [1, 2, 96, 4])
self.assertEqual(p[3:-1:-1], [4, 96, 2, 1])
self.assertEqual(p[:4:3], [1, 4])
c_int()
self.assertEqual(p[:4], [1, 2, 96, 4])
self.assertEqual(p[:4:], [1, 2, 96, 4])
self.assertEqual(p[3:-1:-1], [4, 96, 2, 1])
self.assertEqual(p[:4:3], [1, 4])
def test_char_p(self):
# This didn't work: bad argument to internal function
s = c_char_p("hiho")
self.assertEqual(cast(cast(s, c_void_p), c_char_p).value,
"hiho")
try:
c_wchar_p
except NameError:
pass
else:
def test_wchar_p(self):
s = c_wchar_p("hiho")
self.assertEqual(cast(cast(s, c_void_p), c_wchar_p).value,
"hiho")
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,211 @@
# A lot of failures in these tests on Mac OS X.
# Byte order related?
import unittest
from ctypes import *
import _ctypes_test
class CFunctions(unittest.TestCase):
_dll = CDLL(_ctypes_test.__file__)
def S(self):
return c_longlong.in_dll(self._dll, "last_tf_arg_s").value
def U(self):
return c_ulonglong.in_dll(self._dll, "last_tf_arg_u").value
def test_byte(self):
self._dll.tf_b.restype = c_byte
self._dll.tf_b.argtypes = (c_byte,)
self.assertEqual(self._dll.tf_b(-126), -42)
self.assertEqual(self.S(), -126)
def test_byte_plus(self):
self._dll.tf_bb.restype = c_byte
self._dll.tf_bb.argtypes = (c_byte, c_byte)
self.assertEqual(self._dll.tf_bb(0, -126), -42)
self.assertEqual(self.S(), -126)
def test_ubyte(self):
self._dll.tf_B.restype = c_ubyte
self._dll.tf_B.argtypes = (c_ubyte,)
self.assertEqual(self._dll.tf_B(255), 85)
self.assertEqual(self.U(), 255)
def test_ubyte_plus(self):
self._dll.tf_bB.restype = c_ubyte
self._dll.tf_bB.argtypes = (c_byte, c_ubyte)
self.assertEqual(self._dll.tf_bB(0, 255), 85)
self.assertEqual(self.U(), 255)
def test_short(self):
self._dll.tf_h.restype = c_short
self._dll.tf_h.argtypes = (c_short,)
self.assertEqual(self._dll.tf_h(-32766), -10922)
self.assertEqual(self.S(), -32766)
def test_short_plus(self):
self._dll.tf_bh.restype = c_short
self._dll.tf_bh.argtypes = (c_byte, c_short)
self.assertEqual(self._dll.tf_bh(0, -32766), -10922)
self.assertEqual(self.S(), -32766)
def test_ushort(self):
self._dll.tf_H.restype = c_ushort
self._dll.tf_H.argtypes = (c_ushort,)
self.assertEqual(self._dll.tf_H(65535), 21845)
self.assertEqual(self.U(), 65535)
def test_ushort_plus(self):
self._dll.tf_bH.restype = c_ushort
self._dll.tf_bH.argtypes = (c_byte, c_ushort)
self.assertEqual(self._dll.tf_bH(0, 65535), 21845)
self.assertEqual(self.U(), 65535)
def test_int(self):
self._dll.tf_i.restype = c_int
self._dll.tf_i.argtypes = (c_int,)
self.assertEqual(self._dll.tf_i(-2147483646), -715827882)
self.assertEqual(self.S(), -2147483646)
def test_int_plus(self):
self._dll.tf_bi.restype = c_int
self._dll.tf_bi.argtypes = (c_byte, c_int)
self.assertEqual(self._dll.tf_bi(0, -2147483646), -715827882)
self.assertEqual(self.S(), -2147483646)
def test_uint(self):
self._dll.tf_I.restype = c_uint
self._dll.tf_I.argtypes = (c_uint,)
self.assertEqual(self._dll.tf_I(4294967295), 1431655765)
self.assertEqual(self.U(), 4294967295)
def test_uint_plus(self):
self._dll.tf_bI.restype = c_uint
self._dll.tf_bI.argtypes = (c_byte, c_uint)
self.assertEqual(self._dll.tf_bI(0, 4294967295), 1431655765)
self.assertEqual(self.U(), 4294967295)
def test_long(self):
self._dll.tf_l.restype = c_long
self._dll.tf_l.argtypes = (c_long,)
self.assertEqual(self._dll.tf_l(-2147483646), -715827882)
self.assertEqual(self.S(), -2147483646)
def test_long_plus(self):
self._dll.tf_bl.restype = c_long
self._dll.tf_bl.argtypes = (c_byte, c_long)
self.assertEqual(self._dll.tf_bl(0, -2147483646), -715827882)
self.assertEqual(self.S(), -2147483646)
def test_ulong(self):
self._dll.tf_L.restype = c_ulong
self._dll.tf_L.argtypes = (c_ulong,)
self.assertEqual(self._dll.tf_L(4294967295), 1431655765)
self.assertEqual(self.U(), 4294967295)
def test_ulong_plus(self):
self._dll.tf_bL.restype = c_ulong
self._dll.tf_bL.argtypes = (c_char, c_ulong)
self.assertEqual(self._dll.tf_bL(' ', 4294967295), 1431655765)
self.assertEqual(self.U(), 4294967295)
def test_longlong(self):
self._dll.tf_q.restype = c_longlong
self._dll.tf_q.argtypes = (c_longlong, )
self.assertEqual(self._dll.tf_q(-9223372036854775806), -3074457345618258602)
self.assertEqual(self.S(), -9223372036854775806)
def test_longlong_plus(self):
self._dll.tf_bq.restype = c_longlong
self._dll.tf_bq.argtypes = (c_byte, c_longlong)
self.assertEqual(self._dll.tf_bq(0, -9223372036854775806), -3074457345618258602)
self.assertEqual(self.S(), -9223372036854775806)
def test_ulonglong(self):
self._dll.tf_Q.restype = c_ulonglong
self._dll.tf_Q.argtypes = (c_ulonglong, )
self.assertEqual(self._dll.tf_Q(18446744073709551615), 6148914691236517205)
self.assertEqual(self.U(), 18446744073709551615)
def test_ulonglong_plus(self):
self._dll.tf_bQ.restype = c_ulonglong
self._dll.tf_bQ.argtypes = (c_byte, c_ulonglong)
self.assertEqual(self._dll.tf_bQ(0, 18446744073709551615), 6148914691236517205)
self.assertEqual(self.U(), 18446744073709551615)
def test_float(self):
self._dll.tf_f.restype = c_float
self._dll.tf_f.argtypes = (c_float,)
self.assertEqual(self._dll.tf_f(-42.), -14.)
self.assertEqual(self.S(), -42)
def test_float_plus(self):
self._dll.tf_bf.restype = c_float
self._dll.tf_bf.argtypes = (c_byte, c_float)
self.assertEqual(self._dll.tf_bf(0, -42.), -14.)
self.assertEqual(self.S(), -42)
def test_double(self):
self._dll.tf_d.restype = c_double
self._dll.tf_d.argtypes = (c_double,)
self.assertEqual(self._dll.tf_d(42.), 14.)
self.assertEqual(self.S(), 42)
def test_double_plus(self):
self._dll.tf_bd.restype = c_double
self._dll.tf_bd.argtypes = (c_byte, c_double)
self.assertEqual(self._dll.tf_bd(0, 42.), 14.)
self.assertEqual(self.S(), 42)
def test_longdouble(self):
self._dll.tf_D.restype = c_longdouble
self._dll.tf_D.argtypes = (c_longdouble,)
self.assertEqual(self._dll.tf_D(42.), 14.)
self.assertEqual(self.S(), 42)
def test_longdouble_plus(self):
self._dll.tf_bD.restype = c_longdouble
self._dll.tf_bD.argtypes = (c_byte, c_longdouble)
self.assertEqual(self._dll.tf_bD(0, 42.), 14.)
self.assertEqual(self.S(), 42)
def test_callwithresult(self):
def process_result(result):
return result * 2
self._dll.tf_i.restype = process_result
self._dll.tf_i.argtypes = (c_int,)
self.assertEqual(self._dll.tf_i(42), 28)
self.assertEqual(self.S(), 42)
self.assertEqual(self._dll.tf_i(-42), -28)
self.assertEqual(self.S(), -42)
def test_void(self):
self._dll.tv_i.restype = None
self._dll.tv_i.argtypes = (c_int,)
self.assertEqual(self._dll.tv_i(42), None)
self.assertEqual(self.S(), 42)
self.assertEqual(self._dll.tv_i(-42), None)
self.assertEqual(self.S(), -42)
# The following repeates the above tests with stdcall functions (where
# they are available)
try:
WinDLL
except NameError:
pass
else:
class stdcall_dll(WinDLL):
def __getattr__(self, name):
if name[:2] == '__' and name[-2:] == '__':
raise AttributeError(name)
func = self._FuncPtr(("s_" + name, self))
setattr(self, name, func)
return func
class stdcallCFunctions(CFunctions):
_dll = stdcall_dll(_ctypes_test.__file__)
pass
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,39 @@
import unittest
from ctypes import *
class CHECKED(c_int):
def _check_retval_(value):
# Receives a CHECKED instance.
return str(value.value)
_check_retval_ = staticmethod(_check_retval_)
class Test(unittest.TestCase):
def test_checkretval(self):
import _ctypes_test
dll = CDLL(_ctypes_test.__file__)
self.assertEqual(42, dll._testfunc_p_p(42))
dll._testfunc_p_p.restype = CHECKED
self.assertEqual("42", dll._testfunc_p_p(42))
dll._testfunc_p_p.restype = None
self.assertEqual(None, dll._testfunc_p_p(42))
del dll._testfunc_p_p.restype
self.assertEqual(42, dll._testfunc_p_p(42))
try:
oledll
except NameError:
pass
else:
def test_oledll(self):
self.assertRaises(WindowsError,
oledll.oleaut32.CreateTypeLib2,
0, None, None)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,21 @@
import unittest
from ctypes import *
class X(Structure):
_fields_ = [("foo", c_int)]
class TestCase(unittest.TestCase):
def test_simple(self):
self.assertRaises(TypeError,
delattr, c_int(42), "value")
def test_chararray(self):
self.assertRaises(TypeError,
delattr, (c_char * 5)(), "value")
def test_struct(self):
self.assertRaises(TypeError,
delattr, X(), "foo")
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,19 @@
import sys
from ctypes import *
##class HMODULE(Structure):
## _fields_ = [("value", c_void_p)]
## def __repr__(self):
## return "<HMODULE %s>" % self.value
##windll.kernel32.GetModuleHandleA.restype = HMODULE
##print windll.kernel32.GetModuleHandleA("python23.dll")
##print hex(sys.dllhandle)
##def nonzero(handle):
## return (GetLastError(), handle)
##windll.kernel32.GetModuleHandleA.errcheck = nonzero
##print windll.kernel32.GetModuleHandleA("spam")

View File

@@ -0,0 +1,80 @@
import unittest, os, errno
from ctypes import *
from ctypes.util import find_library
from test import test_support
try:
import threading
except ImportError:
threading = None
class Test(unittest.TestCase):
def test_open(self):
libc_name = find_library("c")
if libc_name is None:
raise unittest.SkipTest("Unable to find C library")
libc = CDLL(libc_name, use_errno=True)
if os.name == "nt":
libc_open = libc._open
else:
libc_open = libc.open
libc_open.argtypes = c_char_p, c_int
self.assertEqual(libc_open("", 0), -1)
self.assertEqual(get_errno(), errno.ENOENT)
self.assertEqual(set_errno(32), errno.ENOENT)
self.assertEqual(get_errno(), 32)
if threading:
def _worker():
set_errno(0)
libc = CDLL(libc_name, use_errno=False)
if os.name == "nt":
libc_open = libc._open
else:
libc_open = libc.open
libc_open.argtypes = c_char_p, c_int
self.assertEqual(libc_open("", 0), -1)
self.assertEqual(get_errno(), 0)
t = threading.Thread(target=_worker)
t.start()
t.join()
self.assertEqual(get_errno(), 32)
set_errno(0)
@unittest.skipUnless(os.name == "nt", 'Test specific to Windows')
def test_GetLastError(self):
dll = WinDLL("kernel32", use_last_error=True)
GetModuleHandle = dll.GetModuleHandleA
GetModuleHandle.argtypes = [c_wchar_p]
self.assertEqual(0, GetModuleHandle("foo"))
self.assertEqual(get_last_error(), 126)
self.assertEqual(set_last_error(32), 126)
self.assertEqual(get_last_error(), 32)
def _worker():
set_last_error(0)
dll = WinDLL("kernel32", use_last_error=False)
GetModuleHandle = dll.GetModuleHandleW
GetModuleHandle.argtypes = [c_wchar_p]
GetModuleHandle("bar")
self.assertEqual(get_last_error(), 0)
t = threading.Thread(target=_worker)
t.start()
t.join()
self.assertEqual(get_last_error(), 32)
set_last_error(0)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,82 @@
import unittest
import sys
from ctypes import *
from ctypes.util import find_library
from ctypes.test import is_resource_enabled
if sys.platform == "win32":
lib_gl = find_library("OpenGL32")
lib_glu = find_library("Glu32")
lib_gle = None
elif sys.platform == "darwin":
lib_gl = lib_glu = find_library("OpenGL")
lib_gle = None
else:
lib_gl = find_library("GL")
lib_glu = find_library("GLU")
lib_gle = find_library("gle")
## print, for debugging
if is_resource_enabled("printing"):
if lib_gl or lib_glu or lib_gle:
print "OpenGL libraries:"
for item in (("GL", lib_gl),
("GLU", lib_glu),
("gle", lib_gle)):
print "\t", item
# On some systems, loading the OpenGL libraries needs the RTLD_GLOBAL mode.
class Test_OpenGL_libs(unittest.TestCase):
def setUp(self):
self.gl = self.glu = self.gle = None
if lib_gl:
self.gl = CDLL(lib_gl, mode=RTLD_GLOBAL)
if lib_glu:
self.glu = CDLL(lib_glu, RTLD_GLOBAL)
if lib_gle:
try:
self.gle = CDLL(lib_gle)
except OSError:
pass
if lib_gl:
def test_gl(self):
if self.gl:
self.gl.glClearIndex
if lib_glu:
def test_glu(self):
if self.glu:
self.glu.gluBeginCurve
if lib_gle:
def test_gle(self):
if self.gle:
self.gle.gleGetJoinStyle
##if os.name == "posix" and sys.platform != "darwin":
## # On platforms where the default shared library suffix is '.so',
## # at least some libraries can be loaded as attributes of the cdll
## # object, since ctypes now tries loading the lib again
## # with '.so' appended of the first try fails.
## #
## # Won't work for libc, unfortunately. OTOH, it isn't
## # needed for libc since this is already mapped into the current
## # process (?)
## #
## # On MAC OSX, it won't work either, because dlopen() needs a full path,
## # and the default suffix is either none or '.dylib'.
## class LoadLibs(unittest.TestCase):
## def test_libm(self):
## import math
## libm = cdll.libm
## sqrt = libm.sqrt
## sqrt.argtypes = (c_double,)
## sqrt.restype = c_double
## self.assertEqual(sqrt(2), math.sqrt(2))
if __name__ == "__main__":
unittest.main()

Some files were not shown because too many files have changed in this diff Show More